Merge branch 'master' into refine_offload

Yao Chi 2025-11-26 16:58:34 +08:00 committed by GitHub
commit d28093f290
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
125 changed files with 11457 additions and 8155 deletions

View File

@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause

View File

@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause

View File

@ -1,2 +1,3 @@
 .\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
+echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
 pause

View File

@ -8,13 +8,15 @@ body:
         Before submitting a **Bug Report**, please ensure the following:

         - **1:** You are running the latest version of ComfyUI.
-        - **2:** You have looked at the existing bug reports and made sure this isn't already reported.
+        - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
         - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
-        `--disable-all-custom-nodes` command line argument.
+        `--disable-all-custom-nodes` command line argument. If you have custom nodes, try updating them to the latest version.
         - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
         steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
         If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
+
+        ## Very Important
+        Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
   - type: checkboxes
     id: custom-nodes-test
     attributes:

View File

@ -0,0 +1,21 @@
<!-- API_NODE_PR_CHECKLIST: do not remove -->
## API Node PR Checklist
### Scope
- [ ] **Is API Node Change**
### Pricing & Billing
- [ ] **Need pricing update**
- [ ] **No pricing update**
If **Need pricing update**:
- [ ] Metronome rate cards updated
- [ ] Autobilling tests updated and passing
### QA
- [ ] **QA done**
- [ ] **QA not required**
### Comms
- [ ] Informed **Kosinkadink**

.github/workflows/api-node-template.yml vendored Normal file (58 lines added)
View File

@ -0,0 +1,58 @@
name: Append API Node PR template

on:
  pull_request_target:
    types: [opened, reopened, synchronize, ready_for_review]
    paths:
      - 'comfy_api_nodes/**'  # only run if these files changed

permissions:
  contents: read
  pull-requests: write

jobs:
  inject:
    runs-on: ubuntu-latest
    steps:
      - name: Ensure template exists and append to PR body
        uses: actions/github-script@v7
        with:
          script: |
            const { owner, repo } = context.repo;
            const number = context.payload.pull_request.number;
            const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
            const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';

            const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });

            let templateText;
            try {
              const res = await github.rest.repos.getContent({
                owner,
                repo,
                path: templatePath,
                ref: pr.base.ref
              });
              const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
              templateText = buf.toString('utf8');
            } catch (e) {
              core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
              return;
            }

            // Enforce the presence of the marker inside the template (for idempotence)
            if (!templateText.includes(marker)) {
              core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
              return;
            }

            // If the PR already contains the marker, do not append again.
            const body = pr.body || '';
            if (body.includes(marker)) {
              core.info('Template already present in PR body; nothing to inject.');
              return;
            }

            const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
            await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
            core.notice('API Node template appended to PR description.');

View File

@ -14,13 +14,13 @@ jobs:
       contents: "write"
       packages: "write"
       pull-requests: "read"
-    name: "Release NVIDIA Default (cu129)"
+    name: "Release NVIDIA Default (cu130)"
     uses: ./.github/workflows/stable-release.yml
     with:
       git_tag: ${{ inputs.git_tag }}
-      cache_tag: "cu129"
+      cache_tag: "cu130"
       python_minor: "13"
-      python_patch: "6"
+      python_patch: "9"
       rel_name: "nvidia"
       rel_extra_name: ""
       test_release: true
@ -43,6 +43,23 @@ jobs:
       test_release: true
     secrets: inherit

+  release_nvidia_cu126:
+    permissions:
+      contents: "write"
+      packages: "write"
+      pull-requests: "read"
+    name: "Release NVIDIA cu126"
+    uses: ./.github/workflows/stable-release.yml
+    with:
+      git_tag: ${{ inputs.git_tag }}
+      cache_tag: "cu126"
+      python_minor: "12"
+      python_patch: "10"
+      rel_name: "nvidia"
+      rel_extra_name: "_cu126"
+      test_release: true
+    secrets: inherit
+
   release_amd_rocm:
     permissions:
       contents: "write"

View File

@ -21,14 +21,15 @@ jobs:
       fail-fast: false
       matrix:
         # os: [macos, linux, windows]
-        os: [macos, linux]
-        python_version: ["3.9", "3.10", "3.11", "3.12"]
+        # os: [macos, linux]
+        os: [linux]
+        python_version: ["3.10", "3.11", "3.12"]
         cuda_version: ["12.1"]
         torch_version: ["stable"]
         include:
-          - os: macos
-            runner_label: [self-hosted, macOS]
-            flags: "--use-pytorch-cross-attention"
+          # - os: macos
+          #   runner_label: [self-hosted, macOS]
+          #   flags: "--use-pytorch-cross-attention"
           - os: linux
             runner_label: [self-hosted, Linux]
             flags: ""
@ -73,14 +74,15 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        os: [macos, linux]
+        # os: [macos, linux]
+        os: [linux]
         python_version: ["3.11"]
         cuda_version: ["12.1"]
         torch_version: ["nightly"]
         include:
-          - os: macos
-            runner_label: [self-hosted, macOS]
-            flags: "--use-pytorch-cross-attention"
+          # - os: macos
+          #   runner_label: [self-hosted, macOS]
+          #   flags: "--use-pytorch-cross-attention"
           - os: linux
             runner_label: [self-hosted, Linux]
             flags: ""

View File

@ -17,7 +17,7 @@ on:
         description: 'cuda version'
         required: true
         type: string
-        default: "129"
+        default: "130"
       python_minor:
         description: 'python minor version'
@ -29,7 +29,7 @@ on:
         description: 'python patch version'
         required: true
         type: string
-        default: "6"
+        default: "9"
 # push:
 # branches:
 #   - master

QUANTIZATION.md Normal file (168 lines added)
View File

@ -0,0 +1,168 @@
# The Comfy guide to Quantization
## How does quantization work?
Quantization aims to map a high-precision value x_f to a lower-precision format with minimal loss in accuracy. These smaller formats then serve to reduce the model's memory footprint and increase throughput by using specialized hardware.

When simply converting a value from FP16 to FP8 using the round-to-nearest method, we might hit two issues:

- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
- The original values are concentrated in a small range (e.g. [-1, 1]), leaving many FP8 bits "unused"

By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest and most common approaches is per-tensor absolute-maximum scaling.
```
amax = max(abs(tensor))
scale = amax / max_dynamic_range_low_precision
# Quantization
tensor_q = (tensor / scale).to(low_precision_dtype)
# De-Quantization
tensor_dq = tensor_q.to(fp16) * scale
tensor_dq ~ tensor
```
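
To make this concrete, here is a minimal absmax roundtrip you can run in PyTorch (a sketch assuming PyTorch 2.1+ for the `float8_e4m3fn` dtype; 448 is the largest normal E4M3 value):

```python
import torch

x = torch.randn(4, 4, dtype=torch.float16)

amax = x.abs().max().to(torch.float32)
scale = amax / 448.0  # map the observed range onto the E4M3 dynamic range

x_q = (x.to(torch.float32) / scale).to(torch.float8_e4m3fn)   # quantize
x_dq = (x_q.to(torch.float32) * scale).to(torch.float16)      # de-quantize

print((x - x_dq).abs().max())  # small but non-zero: the quantization error
```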
Given that additional information (the scaling factor) is needed to "interpret" the quantized values, we describe these as derived datatypes.
## Quantization in Comfy
```
QuantizedTensor (torch.Tensor subclass)
↓ __torch_dispatch__
Two-Level Registry (generic + layout handlers)
    ↓
MixedPrecisionOps + Metadata Detection
```
### Representation
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor: the `QuantizedTensor` class found in `comfy/quant_ops.py`.
A `Layout` class defines how a specific quantization format behaves:
- Required parameters
- Quantize method
- De-Quantize method
```python
from comfy.quant_ops import QuantizedLayout

class MyLayout(QuantizedLayout):
    @classmethod
    def quantize(cls, tensor, **kwargs):
        # Convert to quantized format
        qdata = ...
        params = {'scale': ..., 'orig_dtype': tensor.dtype}
        return qdata, params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale
```
To run operations on these QuantizedTensors, we use two registry systems to define the supported operations.
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
The second registry is layout-specific and allows layouts to implement fast paths like `nn.Linear`.
```python
from comfy.quant_ops import register_layout_op

@register_layout_op(torch.ops.aten.linear.default, MyLayout)
def my_linear(func, args, kwargs):
    # Extract tensors, call optimized kernel
    ...
```
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
For any unsupported operation, `QuantizedTensor` falls back to calling `dequantize` and dispatching to the high-precision implementation.
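
The fallback behavior can be sketched with a toy wrapper subclass. This is illustrative only, not ComfyUI's actual `QuantizedTensor`; it uses PyTorch's standard (private) `_make_wrapper_subclass` pattern for tensor subclasses:

```python
import torch

class ToyQuantTensor(torch.Tensor):
    # Toy derived dtype: int8 data plus a scale, dequantized on any op.
    @staticmethod
    def __new__(cls, qdata, scale):
        # _make_wrapper_subclass is PyTorch's (private) hook for subclass authors.
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=torch.float32)

    def __init__(self, qdata, scale):
        self.qdata = qdata
        self.scale = scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Generic fallback: dequantize every quantized argument, then
        # dispatch to the normal high-precision implementation.
        def dq(t):
            return t.qdata.to(torch.float32) * t.scale if isinstance(t, ToyQuantTensor) else t
        kwargs = kwargs or {}
        return func(*[dq(a) for a in args], **{k: dq(v) for k, v in kwargs.items()})

x = ToyQuantTensor((torch.randn(4, 4) * 10).to(torch.int8), scale=0.1)
print(torch.nn.functional.relu(x))  # no int8 relu kernel needed: falls back
```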
### Mixed Precision
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
**Architecture:**
```python
class MixedPrecisionOps(disable_weight_init):
    _layer_quant_config = {}          # Maps layer names to quantization configs
    _compute_dtype = torch.bfloat16   # Default compute / dequantize precision
```
**Key mechanism:**
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
- If the layer name **is** in `_layer_quant_config`:
  - Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
  - Load associated quantization parameters (scales, block_size, etc.)
**Why it's needed:**
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
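
To make the shape of this concrete, a config along these lines (keys are hypothetical; the real schema is defined by the model configs and `comfy/ops.py`) would quantize a compute-heavy matmul while leaving unlisted layers in `_compute_dtype`:

```python
# Hypothetical illustration of a per-layer quantization config.
layer_quant_config = {
    "double_blocks.0.img_mlp.0": {   # compute-heavy matmul: quantize
        "format": "float8_e4m3fn",
    },
    # "final_layer.linear" is deliberately absent: it is loaded as a
    # regular tensor in _compute_dtype (e.g. torch.bfloat16).
}
```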
## Checkpoint Format
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
The quantized checkpoint will contain the same layers as the original checkpoint but:
- The weights are stored as quantized values, sometimes using a different storage datatype, e.g. a uint8 container for FP8.
- For each quantized weight, a number of additional scaling parameters are stored alongside it, depending on the recipe.
- We store the `_quantization_metadata` JSON entry in the metadata of the final safetensors file, describing which layers are quantized and which layout has been used.
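
A minimal sketch of writing a single FP8-quantized layer in this format (assuming a recent `safetensors` release with FP8 support; the layer name is illustrative):

```python
import json
import torch
from safetensors.torch import save_file

w = torch.randn(64, 64)
scale = w.abs().max() / 448.0                # absmax scaling for E4M3
w_q = (w / scale).to(torch.float8_e4m3fn)    # quantized weight

tensors = {
    "model.layers.0.mlp.up_proj.weight": w_q,
    "model.layers.0.mlp.up_proj.weight_scale": scale.to(torch.float32),
}
metadata = {
    "_quantization_metadata": json.dumps({
        "format_version": "1.0",
        "layers": {"model.layers.0.mlp.up_proj": "float8_e4m3fn"},
    })
}
save_file(tensors, "model_fp8.safetensors", metadata=metadata)
```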
### Scaling Parameters details
We define 4 possible scaling parameters that should cover most recipes in the near future:
- **weight_scale**: quantization scalers for the weights
- **weight_scale_2**: global scalers in the context of double scaling
- **pre_quant_scale**: scalers used for smoothing salient weights
- **input_scale**: quantization scalers for the activations
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|--------|---------------|--------------|----------------|-----------------|-------------|
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
### Quantization Metadata
The metadata stored alongside the checkpoint contains:
- **format_version**: String to define a version of the standard
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
Example:
```json
{
"_quantization_metadata": {
"format_version": "1.0",
"layers": {
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
}
}
}
```
## Creating Quantized Checkpoints
To create compatible checkpoints, use any quantization tool, provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
### Weight Quantization
Weight quantization is straightforward: compute the scaling factor directly from the weight tensor using the absolute-maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
### Calibration (for Activation Quantization)
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from the static weights alone. Since activation values depend on actual inputs, we use **post-training quantization (PTQ)** calibration:
1. **Collect statistics**: Run inference on N representative samples
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
3. **Compute scales**: Derive `input_scale` from collected statistics
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
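
As an illustration, a minimal calibration loop might look like the following generic PyTorch sketch (not ComfyUI tooling; the toy model and random samples stand in for a real pipeline and prompt set):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 8))
calibration_samples = [torch.randn(4, 8) for _ in range(32)]  # representative inputs

amax_stats = {}

def make_hook(name):
    def hook(module, inputs, output):
        amax = inputs[0].detach().abs().max().item()
        amax_stats[name] = max(amax_stats.get(name, 0.0), amax)  # running amax
    return hook

handles = [module.register_forward_hook(make_hook(name))
           for name, module in model.named_modules()
           if isinstance(module, torch.nn.Linear)]

with torch.no_grad():
    for batch in calibration_samples:  # 1. run inference on N samples
        model(batch)                   # 2. hooks track per-layer input amax

for handle in handles:
    handle.remove()

# 3. derive input_scale (448 is the max normal value of FP8 E4M3)
input_scales = {name: amax / 448.0 for name, amax in amax_stats.items()}
print(input_scales)                    # 4. stored alongside the weights
```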

View File

@ -67,6 +67,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
    - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
    - [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
    - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
+   - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
 - Image Editing Models
    - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
    - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
@ -112,10 +113,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 ## Release Process

-ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
+ComfyUI follows a weekly release cycle targeting Monday, but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:

 1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
-   - Releases a new stable version (e.g., v0.7.0)
+   - Releases a new stable version (e.g., v0.7.0) roughly every week.
+   - Commits outside of the stable release tags may be very unstable and break many custom nodes.
    - Serves as the foundation for the desktop release

 2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
@ -172,15 +174,19 @@ There is a portable standalone build for Windows that should work for running on
 ### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)

-Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
+Simply download, extract with [7-Zip](https://7-zip.org) (or with Windows Explorer on recent Windows versions) and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints, but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder of ComfyUI\models\ to put them in.

 If you have trouble extracting it, right click the file -> properties -> unblock
+
+Update your Nvidia drivers if it doesn't start.

 #### Alternative Downloads:

 [Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)

-[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
+[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
+
+[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
#### How do I share models between another UI and ComfyUI? #### How do I share models between another UI and ComfyUI?
@ -197,10 +203,12 @@ comfy install
 ## Manual Install (Windows, Linux)

-Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
+Python 3.14 works, but you may encounter issues with the torch compile node. The free-threaded variant is still missing some dependencies.

 Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12

+### Instructions:
+
 Git clone this repo.

 Put your SD checkpoints (the huge ckpt/safetensors files) in: models/checkpoints
@ -216,7 +224,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
 This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only. ### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
@ -237,7 +245,7 @@ RDNA 4 (RX 9000 series):
 ### Intel GPUs (Windows and Linux)

-(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)

 1. To install PyTorch xpu, use the following command:
@ -247,10 +255,6 @@ This is the command to install the Pytorch xpu nightly which might have some per
 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```

-(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
-
-1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
-
 ### NVIDIA

 Nvidia users should install stable pytorch using this command:

View File

@ -10,7 +10,8 @@ import importlib
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import TypedDict, Optional
+from typing import Dict, TypedDict, Optional
+from aiohttp import web
 from importlib.metadata import version

 import requests
@ -257,7 +258,54 @@ comfyui-frontend-package is not installed.
         sys.exit(-1)

     @classmethod
-    def templates_path(cls) -> str:
+    def template_asset_map(cls) -> Optional[Dict[str, str]]:
+        """Return a mapping of template asset names to their absolute paths."""
+        try:
+            from comfyui_workflow_templates import (
+                get_asset_path,
+                iter_templates,
+            )
+        except ImportError:
+            logging.error(
+                f"""
+********** ERROR ***********
+comfyui-workflow-templates is not installed.
+{frontend_install_warning_message()}
+********** ERROR ***********
+""".strip()
+            )
+            return None
+
+        try:
+            template_entries = list(iter_templates())
+        except Exception as exc:
+            logging.error(f"Failed to enumerate workflow templates: {exc}")
+            return None
+
+        asset_map: Dict[str, str] = {}
+        try:
+            for entry in template_entries:
+                for asset in entry.assets:
+                    asset_map[asset.filename] = get_asset_path(
+                        entry.template_id, asset.filename
+                    )
+        except Exception as exc:
+            logging.error(f"Failed to resolve template asset paths: {exc}")
+            return None
+
+        if not asset_map:
+            logging.error("No workflow template assets found. Did the packages install correctly?")
+            return None
+
+        return asset_map
+
+    @classmethod
+    def legacy_templates_path(cls) -> Optional[str]:
+        """Return the legacy templates directory shipped inside the meta package."""
         try:
             import comfyui_workflow_templates
@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed.
 ********** ERROR ***********
 """.strip()
         )
+        return None

     @classmethod
     def embedded_docs_path(cls) -> str:
@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed.
         logging.info("Falling back to the default frontend.")
         check_frontend_version()
         return cls.default_frontend_path()
+
+    @classmethod
+    def template_asset_handler(cls):
+        assets = cls.template_asset_map()
+        if not assets:
+            return None
+
+        async def serve_template(request: web.Request) -> web.StreamResponse:
+            rel_path = request.match_info.get("path", "")
+            target = assets.get(rel_path)
+            if target is None:
+                raise web.HTTPNotFound()
+            return web.FileResponse(target)
+
+        return serve_template

app/subgraph_manager.py Normal file (112 lines added)
View File

@ -0,0 +1,112 @@
from __future__ import annotations
from typing import TypedDict
import os
import folder_paths
import glob
from aiohttp import web
import hashlib


class Source:
    custom_node = "custom_node"


class SubgraphEntry(TypedDict):
    source: str
    """
    Source of subgraph - custom_nodes vs templates.
    """
    path: str
    """
    Relative path of the subgraph file.
    For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
    """
    name: str
    """
    Name of subgraph file.
    """
    info: CustomNodeSubgraphEntryInfo
    """
    Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
    """
    data: str


class CustomNodeSubgraphEntryInfo(TypedDict):
    node_pack: str
    """Node pack name."""


class SubgraphManager:
    def __init__(self):
        self.cached_custom_node_subgraphs: dict[str, SubgraphEntry] | None = None

    async def load_entry_data(self, entry: SubgraphEntry):
        with open(entry['path'], 'r') as f:
            entry['data'] = f.read()
        return entry

    async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
        if entry is None:
            return None
        entry = entry.copy()
        entry.pop('path', None)
        if remove_data:
            entry.pop('data', None)
        return entry

    async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
        entries = entries.copy()
        for key in list(entries.keys()):
            entries[key] = await self.sanitize_entry(entries[key], remove_data)
        return entries

    async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
        # if not forced to reload and cached, return cache
        if not force_reload and self.cached_custom_node_subgraphs is not None:
            return self.cached_custom_node_subgraphs
        # Load subgraphs from custom nodes
        subfolder = "subgraphs"
        subgraphs_dict: dict[str, SubgraphEntry] = {}
        for folder in folder_paths.get_folder_paths("custom_nodes"):
            pattern = os.path.join(folder, f"*/{subfolder}/*.json")
            matched_files = glob.glob(pattern)
            for file in matched_files:
                # replace backslashes with forward slashes
                file = file.replace('\\', '/')
                info: CustomNodeSubgraphEntryInfo = {
                    "node_pack": "custom_nodes." + file.split('/')[-3]
                }
                source = Source.custom_node
                # hash source + path to make sure id will be as unique as possible, but
                # reproducible across backend reloads
                id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
                entry: SubgraphEntry = {
                    "source": Source.custom_node,
                    "name": os.path.splitext(os.path.basename(file))[0],
                    "path": file,
                    "info": info,
                }
                subgraphs_dict[id] = entry
        self.cached_custom_node_subgraphs = subgraphs_dict
        return subgraphs_dict

    async def get_custom_node_subgraph(self, id: str, loadedModules):
        subgraphs = await self.get_custom_node_subgraphs(loadedModules)
        entry: SubgraphEntry = subgraphs.get(id, None)
        if entry is not None and entry.get('data', None) is None:
            await self.load_entry_data(entry)
        return entry

    def add_routes(self, routes, loadedModules):
        @routes.get("/global_subgraphs")
        async def get_global_subgraphs(request):
            subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
            # NOTE: we may want to include other sources of global subgraphs such as templates in the future;
            # that's the reasoning for the current implementation
            return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))

        @routes.get("/global_subgraphs/{id}")
        async def get_global_subgraph(request):
            id = request.match_info.get("id", None)
            subgraph = await self.get_custom_node_subgraph(id, loadedModules)
            return web.json_response(await self.sanitize_entry(subgraph))

View File

@ -413,7 +413,8 @@ class ControlNet(nn.Module):
         out_middle = []

         if self.num_classes is not None:
-            assert y.shape[0] == x.shape[0]
+            if y is None:
+                raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
             emb = emb + self.label_emb(y)

         h = x

View File

@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -145,7 +146,9 @@ class PerformanceFeature(enum.Enum):
     CublasOps = "cublas_ops"
     AutoTune = "autotune"

-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your ComfyUI. --fast with no arguments enables everything. You can pass a list of specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
+parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")

 parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
 parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
@ -157,7 +160,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.") parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.") parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")

View File

@ -310,11 +310,13 @@ class ControlLoraOps:
             self.bias = None

         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
+                x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
             else:
-                return torch.nn.functional.linear(input, weight, bias)
+                x = torch.nn.functional.linear(input, weight, bias)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

     class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
         def __init__(
@ -350,12 +352,13 @@ class ControlLoraOps:
         def forward(self, input):
-            weight, bias = comfy.ops.cast_bias_weight(self, input)
+            weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
             else:
-                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+                x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+            comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
+            return x


 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options

View File

@ -178,6 +178,15 @@ class Flux(SD3):
     def process_out(self, latent):
         return (latent / self.scale_factor) + self.shift_factor

+
+class Flux2(LatentFormat):
+    latent_channels = 128
+
+    def process_in(self, latent):
+        return latent
+
+    def process_out(self, latent):
+        return latent
+
+
 class Mochi(LatentFormat):
     latent_channels = 12
     latent_dimensions = 3
@ -611,6 +620,66 @@ class HunyuanImage21Refiner(LatentFormat):
     latent_dimensions = 3
     scale_factor = 1.03682

+    def process_in(self, latent):
+        out = latent * self.scale_factor
+        out = torch.cat((out[:, :, :1], out), dim=2)
+        out = out.permute(0, 2, 1, 3, 4)
+        b, f_times_2, c, h, w = out.shape
+        out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+        out = out.permute(0, 2, 1, 3, 4).contiguous()
+        return out
+
+    def process_out(self, latent):
+        z = latent / self.scale_factor
+        z = z.permute(0, 2, 1, 3, 4)
+        b, f, c, h, w = z.shape
+        z = z.reshape(b, f, 2, c // 2, h, w)
+        z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
+        z = z.permute(0, 2, 1, 3, 4)
+        z = z[:, :, 1:]
+        return z
+
+
+class HunyuanVideo15(LatentFormat):
+    latent_rgb_factors = [
+        [ 0.0568, -0.0521, -0.0131],
+        [ 0.0014,  0.0735,  0.0326],
+        [ 0.0186,  0.0531, -0.0138],
+        [-0.0031,  0.0051,  0.0288],
+        [ 0.0110,  0.0556,  0.0432],
+        [-0.0041, -0.0023, -0.0485],
+        [ 0.0530,  0.0413,  0.0253],
+        [ 0.0283,  0.0251,  0.0339],
+        [ 0.0277, -0.0372, -0.0093],
+        [ 0.0393,  0.0944,  0.1131],
+        [ 0.0020,  0.0251,  0.0037],
+        [-0.0017,  0.0012,  0.0234],
+        [ 0.0468,  0.0436,  0.0203],
+        [ 0.0354,  0.0439, -0.0233],
+        [ 0.0090,  0.0123,  0.0346],
+        [ 0.0382,  0.0029,  0.0217],
+        [ 0.0261, -0.0300,  0.0030],
+        [-0.0088, -0.0220, -0.0283],
+        [-0.0272, -0.0121, -0.0363],
+        [-0.0664, -0.0622,  0.0144],
+        [ 0.0414,  0.0479,  0.0529],
+        [ 0.0355,  0.0612, -0.0247],
+        [ 0.0147,  0.0264,  0.0174],
+        [ 0.0438,  0.0038,  0.0542],
+        [ 0.0431, -0.0573, -0.0033],
+        [-0.0162, -0.0211, -0.0406],
+        [-0.0487, -0.0295, -0.0393],
+        [ 0.0005, -0.0109,  0.0253],
+        [ 0.0296,  0.0591,  0.0353],
+        [ 0.0119,  0.0181, -0.0306],
+        [-0.0085, -0.0362,  0.0229],
+        [ 0.0005, -0.0106,  0.0242]
+    ]
+    latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
+    latent_channels = 32
+    latent_dimensions = 3
+    scale_factor = 1.03682
+
+
 class Hunyuan3Dv2(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1

View File

@ -1,15 +1,15 @@
 import torch
 from torch import Tensor, nn

-from comfy.ldm.flux.math import attention
 from comfy.ldm.flux.layers import (
     MLPEmbedder,
     RMSNorm,
-    QKNorm,
-    SelfAttention,
     ModulationOut,
 )

+# TODO: remove this in a few months
+SingleStreamBlock = None
+DoubleStreamBlock = None
+

 class ChromaModulationOut(ModulationOut):
@ -48,124 +48,6 @@ class Approximator(nn.Module):
         return x


-class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
-        super().__init__()
-
-        mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        self.num_heads = num_heads
-        self.hidden_size = hidden_size
-        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-
-        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
-
-        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
-        self.flipped_img_txt = flipped_img_txt
-
-    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
-        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
-
-        # prepare image for attention
-        img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
-        img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
-
-        # prepare txt for attention
-        txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
-        txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
-
-        # run actual attention
-        attn = attention(torch.cat((txt_q, img_q), dim=2),
-                         torch.cat((txt_k, img_k), dim=2),
-                         torch.cat((txt_v, img_v), dim=2),
-                         pe=pe, mask=attn_mask, transformer_options=transformer_options)
-        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
-
-        # calculate the img bloks
-        img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
-        img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
-
-        # calculate the txt bloks
-        txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
-        txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
-
-        if txt.dtype == torch.float16:
-            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
-
-        return img, txt
-
-
-class SingleStreamBlock(nn.Module):
-    """
-    A DiT block with parallel linear layers as described in
-    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        num_heads: int,
-        mlp_ratio: float = 4.0,
-        qk_scale: float = None,
-        dtype=None,
-        device=None,
-        operations=None
-    ):
-        super().__init__()
-        self.hidden_dim = hidden_size
-        self.num_heads = num_heads
-        head_dim = hidden_size // num_heads
-        self.scale = qk_scale or head_dim**-0.5
-
-        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
-        # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
-        # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
-
-        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-
-        self.hidden_size = hidden_size
-        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-
-        self.mlp_act = nn.GELU(approximate="tanh")
-
-    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
-        mod = vec
-        x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
-        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
-
-        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k = self.norm(q, k, v)
-
-        # compute attention
-        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
-        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        x.addcmul_(mod.gate, output)
-
-        if x.dtype == torch.float16:
-            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
-        return x
-
-
 class LastLayer(nn.Module):
     def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
         super().__init__()

View File

@ -11,12 +11,12 @@ import comfy.ldm.common_dit
 from comfy.ldm.flux.layers import (
     EmbedND,
     timestep_embedding,
+    DoubleStreamBlock,
+    SingleStreamBlock,
 )

 from .layers import (
-    DoubleStreamBlock,
     LastLayer,
-    SingleStreamBlock,
     Approximator,
     ChromaModulationOut,
 )
@ -90,6 +90,7 @@ class Chroma(nn.Module):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@ -98,7 +99,7 @@ class Chroma(nn.Module):
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@ -178,7 +179,10 @@ class Chroma(nn.Module):
         pe = self.pe_embedder(ids)

         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_mmdit:
                 double_mod = (
                     self.get_modulations(mod_vectors, "double_img", idx=i),
@ -221,7 +225,10 @@ class Chroma(nn.Module):
         img = torch.cat((txt, img), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if i not in self.skip_dit:
                 single_mod = self.get_modulations(mod_vectors, "single", idx=i)
                 if ("single_block", i) in blocks_replace:

View File

@ -10,12 +10,10 @@ from torch import Tensor, nn
 from einops import repeat

 import comfy.ldm.common_dit
-from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
 from comfy.ldm.chroma.model import Chroma, ChromaParams
 from comfy.ldm.chroma.layers import (
-    DoubleStreamBlock,
-    SingleStreamBlock,
     Approximator,
 )
 from .layers import (
@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
             dtype=dtype, device=device, operations=operations
         )
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
                     qkv_bias=params.qkv_bias,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
                     self.hidden_size,
                     self.num_heads,
                     mlp_ratio=params.mlp_ratio,
+                    modulation=False,
                     dtype=dtype, device=device, operations=operations,
                 )
                 for _ in range(params.depth_single_blocks)

View File

@ -48,11 +48,11 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding


 class MLPEmbedder(nn.Module):
-    def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
+    def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
-        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
         self.silu = nn.SiLU()
-        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)

     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))
@ -80,14 +80,14 @@ class QKNorm(torch.nn.Module):
 class SelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
         super().__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads

         self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
-        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)


 @dataclass
@ -98,11 +98,11 @@ class ModulationOut:
 class Modulation(nn.Module):
-    def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
+    def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
         super().__init__()
         self.is_double = double
         self.multiplier = 6 if double else 3
-        self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
+        self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)

     def forward(self, vec: Tensor) -> tuple:
         if vec.ndim == 2:
@ -129,77 +129,129 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
     return tensor


+class SiLUActivation(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.gate_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x1, x2 = x.chunk(2, dim=-1)
+        return self.gate_fn(x1) * x2
+
+
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
         super().__init__()

         mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.num_heads = num_heads
         self.hidden_size = hidden_size
-        self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        self.modulation = modulation
+        if self.modulation:
+            self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
         self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.img_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+        if mlp_silu_act:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.img_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

-        self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
+        if self.modulation:
+            self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
         self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)

         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.txt_mlp = nn.Sequential(
-            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-            nn.GELU(approximate="tanh"),
-            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-        )
+        if mlp_silu_act:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+                SiLUActivation(),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+            )
+        else:
+            self.txt_mlp = nn.Sequential(
+                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+                nn.GELU(approximate="tanh"),
+                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+            )

         self.flipped_img_txt = flipped_img_txt

     def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
-        img_mod1, img_mod2 = self.img_mod(vec)
-        txt_mod1, txt_mod2 = self.txt_mod(vec)
+        if self.modulation:
+            img_mod1, img_mod2 = self.img_mod(vec)
+            txt_mod1, txt_mod2 = self.txt_mod(vec)
+        else:
+            (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

         # prepare image for attention
         img_modulated = self.img_norm1(img)
         img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
+        del img_modulated
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del img_qkv
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
+        del txt_modulated
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del txt_qkv
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         if self.flipped_img_txt:
+            q = torch.cat((img_q, txt_q), dim=2)
+            del img_q, txt_q
+            k = torch.cat((img_k, txt_k), dim=2)
+            del img_k, txt_k
+            v = torch.cat((img_v, txt_v), dim=2)
+            del img_v, txt_v
             # run actual attention
-            attn = attention(torch.cat((img_q, txt_q), dim=2),
-                             torch.cat((img_k, txt_k), dim=2),
-                             torch.cat((img_v, txt_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
         else:
+            q = torch.cat((txt_q, img_q), dim=2)
+            del txt_q, img_q
+            k = torch.cat((txt_k, img_k), dim=2)
+            del txt_k, img_k
+            v = torch.cat((txt_v, img_v), dim=2)
+            del txt_v, img_v
             # run actual attention
-            attn = attention(torch.cat((txt_q, img_q), dim=2),
-                             torch.cat((txt_k, img_k), dim=2),
-                             torch.cat((txt_v, img_v), dim=2),
+            attn = attention(q, k, v,
                              pe=pe, mask=attn_mask, transformer_options=transformer_options)
+            del q, k, v

             txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
-        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
-        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
+        img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        del img_attn
+        img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
         txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        del txt_attn
         txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16: if txt.dtype == torch.float16:
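Note: with the new `modulation=False` path, `vec` changes type: it is no longer the conditioning embedding but a pre-unpacked pair of modulation tuples. A sketch of the calling convention, written from the model's perspective (fragment only; `double_stream_modulation_img`/`_txt` are the shared `Modulation` modules added to `Flux.__init__` below):

    # Sketch: precompute one modulation set per step and reuse it for every double block.
    img_mods = self.double_stream_modulation_img(vec_orig)   # (ModulationOut, ModulationOut)
    txt_mods = self.double_stream_modulation_txt(vec_orig)
    for block in self.double_blocks:
        img, txt = block(img, txt, vec=(img_mods, txt_mods), pe=pe)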
@@ -220,6 +272,9 @@ class SingleStreamBlock(nn.Module):
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
+        modulation=True,
+        mlp_silu_act=False,
+        bias=True,
        dtype=None,
        device=None,
        operations=None
@@ -231,30 +286,47 @@ class SingleStreamBlock(nn.Module):
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+        self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        if mlp_silu_act:
+            self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
+            self.mlp_act = SiLUActivation()
+        else:
+            self.mlp_act = nn.GELU(approximate="tanh")
        # qkv and mlp_in
-        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
        # proj and mlp_out
-        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

-        self.mlp_act = nn.GELU(approximate="tanh")
-        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        if modulation:
+            self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
+        else:
+            self.modulation = None

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
-        mod, _ = self.modulation(vec)
-        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+        if self.modulation:
+            mod, _ = self.modulation(vec)
+        else:
+            mod = vec
+        qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+        del qkv
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+        del q, k, v
        # compute activation in mlp stream, cat again and run second linear layer
-        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+        mlp = self.mlp_act(mlp)
+        output = self.linear2(torch.cat((attn, mlp), 2))
        x += apply_mod(output, mod.gate, None, modulation_dims)
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
@@ -262,11 +334,11 @@ class SingleStreamBlock(nn.Module):
class LastLayer(nn.Module):
-    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
-        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))

    def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
        if vec.ndim == 2:
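`SiLUActivation` itself is outside this hunk. Given that `mlp_silu_act` doubles the first linear's MLP width (`mlp_hidden_dim * 2` / `mlp_hidden_dim_first`) while the output projection still consumes `mlp_hidden_dim`, it is presumably a SwiGLU-style gated activation that halves the channel count again. A minimal sketch under that assumption (which half acts as the gate is a guess):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class SiLUActivation(nn.Module):
        # Assumed behavior: split the doubled hidden dim into value/gate halves,
        # gate the value with SiLU, returning half as many channels.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x, gate = x.chunk(2, dim=-1)
            return x * F.silu(gate)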

View File

@@ -7,15 +7,8 @@ import comfy.model_management

def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
-    q_shape = q.shape
-    k_shape = k.shape

    if pe is not None:
-        q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
-        k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
-        q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
-        k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
+        q, k = apply_rope(q, k, pe)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
    return x
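The inline rotary math moves into `apply_rope` (and the single-tensor `apply_rope1` used by the Lightricks and Lumina changes below). A sketch of what those helpers must compute, reconstructed from the deleted lines above; the exact dtype-casting details are assumptions:

    def apply_rope1(t, freqs_cis):
        # View the last dim as 2-vectors and rotate each pair by the 2x2 matrix
        # stored in freqs_cis[..., 0] and freqs_cis[..., 1].
        t_ = t.to(dtype=freqs_cis.dtype).reshape(*t.shape[:-1], -1, 1, 2)
        t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
        return t_out.reshape(*t.shape).type_as(t)

    def apply_rope(q, k, freqs_cis):
        return apply_rope1(q, freqs_cis), apply_rope1(k, freqs_cis)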

View File

@@ -15,6 +15,7 @@ from .layers import (
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
+    Modulation
)

@dataclass
@@ -33,6 +34,11 @@ class FluxParams:
    patch_size: int
    qkv_bias: bool
    guidance_embed: bool
+    global_modulation: bool = False
+    mlp_silu_act: bool = False
+    ops_bias: bool = True
+    default_ref_method: str = "offset"
+    ref_index_scale: float = 1.0

class Flux(nn.Module):
@@ -58,13 +64,17 @@ class Flux(nn.Module):
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
-        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
-        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+        if params.vec_in_dim is not None:
+            self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
+        else:
+            self.vector_in = None
        self.guidance_in = (
-            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
+            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
        )
-        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)

        self.double_blocks = nn.ModuleList(
            [
@@ -73,6 +83,9 @@ class Flux(nn.Module):
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
+                    modulation=params.global_modulation is False,
+                    mlp_silu_act=params.mlp_silu_act,
+                    proj_bias=params.ops_bias,
                    dtype=dtype, device=device, operations=operations
                )
                for _ in range(params.depth)
@@ -81,13 +94,30 @@ class Flux(nn.Module):
        self.single_blocks = nn.ModuleList(
            [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
                for _ in range(params.depth_single_blocks)
            ]
        )

        if final_layer:
-            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)

+        if params.global_modulation:
+            self.double_stream_modulation_img = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.double_stream_modulation_txt = Modulation(
+                self.hidden_size,
+                double=True,
+                bias=False,
+                dtype=dtype, device=device, operations=operations
+            )
+            self.single_stream_modulation = Modulation(
+                self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
+            )

    def forward_orig(
        self,
@@ -103,9 +133,6 @@ class Flux(nn.Module):
        attn_mask: Tensor = None,
    ) -> Tensor:
-        if y is None:
-            y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)

        patches = transformer_options.get("patches", {})
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
@@ -118,9 +145,17 @@ class Flux(nn.Module):
        if guidance is not None:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

-        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+        if self.vector_in is not None:
+            if y is None:
+                y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
+            vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
        txt = self.txt_in(txt)

+        vec_orig = vec
+        if self.params.global_modulation:
+            vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))

        if "post_input" in patches:
            for p in patches["post_input"]:
                out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
@@ -177,6 +212,9 @@ class Flux(nn.Module):

        img = torch.cat((txt, img), 1)

+        if self.params.global_modulation:
+            vec, _ = self.single_stream_modulation(vec_orig)

        for i, block in enumerate(self.single_blocks):
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
@@ -207,10 +245,10 @@ class Flux(nn.Module):
        img = img[:, txt.shape[1] :, ...]

-        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
+        img = self.final_layer(img, vec_orig)  # (N, T, patch_size ** 2 * out_channels)
        return img

-    def process_img(self, x, index=0, h_offset=0, w_offset=0):
+    def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
        bs, c, h, w = x.shape
        patch_size = self.patch_size
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -222,10 +260,22 @@ class Flux(nn.Module):
        h_offset = ((h_offset + (patch_size // 2)) // patch_size)
        w_offset = ((w_offset + (patch_size // 2)) // patch_size)

-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        steps_h = h_len
+        steps_w = w_len
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+            index += rope_options.get("shift_t", 0.0)
+            h_offset += rope_options.get("shift_y", 0.0)
+            w_offset += rope_options.get("shift_x", 0.0)
+
+        img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
        img_ids[:, :, 0] = img_ids[:, :, 1] + index
-        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
-        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
+        img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
        return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)

    def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
@@ -241,16 +291,16 @@ class Flux(nn.Module):
        h_len = ((h_orig + (patch_size // 2)) // patch_size)
        w_len = ((w_orig + (patch_size // 2)) // patch_size)

-        img, img_ids = self.process_img(x)
+        img, img_ids = self.process_img(x, transformer_options=transformer_options)
        img_tokens = img.shape[1]
        if ref_latents is not None:
            h = 0
            w = 0
            index = 0
-            ref_latents_method = kwargs.get("ref_latents_method", "offset")
+            ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
            for ref in ref_latents:
                if ref_latents_method == "index":
-                    index += 1
+                    index += self.params.ref_index_scale
                    h_offset = 0
                    w_offset = 0
                elif ref_latents_method == "uxo":
@@ -274,7 +324,11 @@ class Flux(nn.Module):
            img = torch.cat([img, kontext], dim=1)
            img_ids = torch.cat([img_ids, kontext_ids], dim=1)

-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
+        if len(self.params.axes_dim) == 4:  # Flux 2
+            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        out = out[:, :img_tokens]
-        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
+        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
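`rope_options` rescales the RoPE grid without changing the token count: `torch.linspace` still emits `steps_h`/`steps_w` positions, but the endpoints are stretched by `scale_y`/`scale_x` and offset by the `shift_*` values, so scales above 1 spread the positional coordinates apart. A hypothetical call site (the key names are the ones read above; the values are illustrative):

    transformer_options["rope_options"] = {
        "scale_x": 2.0,  # spread width coordinates as if the grid were ~2x wider
        "scale_y": 2.0,
        "shift_x": 0.0,
        "shift_y": 0.0,
        "shift_t": 0.0,  # offsets the index axis used for reference latents
    }
    img, img_ids = model.process_img(x, transformer_options=transformer_options)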

View File

@@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
import comfy.ldm.modules.diffusionmodules.mmdit
from comfy.ldm.modules.attention import optimized_attention

from dataclasses import dataclass
from einops import repeat

@@ -42,6 +41,8 @@ class HunyuanVideoParams:
    guidance_embed: bool
    byt5: bool
    meanflow: bool
+    use_cond_type_embedding: bool
+    vision_in_dim: int

class SelfAttentionRef(nn.Module):
@@ -157,7 +158,10 @@ class TokenRefiner(nn.Module):
        t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
        # m = mask.float().unsqueeze(-1)
        # c = (x.float() * m).sum(dim=1) / m.sum(dim=1)  #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
-        c = x.sum(dim=1) / x.shape[1]
+        if x.dtype == torch.float16:
+            c = x.float().sum(dim=1) / x.shape[1]
+        else:
+            c = x.sum(dim=1) / x.shape[1]

        c = t + self.c_embedder(c.to(x.dtype))
        x = self.input_embedder(x)
@@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
        super().__init__()
        self.dtype = dtype
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}

        params = HunyuanVideoParams(**kwargs)
        self.params = params
        self.patch_size = params.patch_size
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
+        self.use_cond_type_embedding = params.use_cond_type_embedding
+        self.vision_in_dim = params.vision_in_dim

        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -266,6 +274,18 @@ class HunyuanVideo(nn.Module):
        if final_layer:
            self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)

+        # HunyuanVideo 1.5 specific modules
+        if self.vision_in_dim is not None:
+            from comfy.ldm.wan.model import MLPProj
+            self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
+        else:
+            self.vision_in = None
+
+        if self.use_cond_type_embedding:
+            # 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
+            self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
+        else:
+            self.cond_type_embedding = None

    def forward_orig(
        self,
        img: Tensor,
@@ -276,6 +296,7 @@ class HunyuanVideo(nn.Module):
        timesteps: Tensor,
        y: Tensor = None,
        txt_byt5=None,
+        clip_fea=None,
        guidance: Tensor = None,
        guiding_frame_index=None,
        ref_latent=None,
@@ -331,12 +352,31 @@ class HunyuanVideo(nn.Module):

        txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)

+        if self.cond_type_embedding is not None:
+            self.cond_type_embedding.to(txt.device)
+            cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
+            txt = txt + cond_emb.to(txt.dtype)
+
        if self.byt5_in is not None and txt_byt5 is not None:
            txt_byt5 = self.byt5_in(txt_byt5)
+            if self.cond_type_embedding is not None:
+                cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
+                txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
+                txt = torch.cat((txt_byt5, txt), dim=1)  # byt5 first for HunyuanVideo1.5
+            else:
+                txt = torch.cat((txt, txt_byt5), dim=1)
            txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
-            txt = torch.cat((txt, txt_byt5), dim=1)
            txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)

+        if clip_fea is not None:
+            txt_vision_states = self.vision_in(clip_fea)
+            if self.cond_type_embedding is not None:
+                cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
+                txt_vision_states = txt_vision_states + cond_emb
+            txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
+            extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
+            txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
+
        ids = torch.cat((img_ids, txt_ids), dim=1)
        pe = self.pe_embedder(ids)
@@ -349,7 +389,10 @@ class HunyuanVideo(nn.Module):
            attn_mask = None

        blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
@@ -371,7 +414,10 @@ class HunyuanVideo(nn.Module):

        img = torch.cat((img, txt), 1)

+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
        for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
            if ("single_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
@@ -430,14 +476,14 @@ class HunyuanVideo(nn.Module):
        img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
        return repeat(img_ids, "h w c -> b (h w) c", b=bs)

-    def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
-        ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
+        ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)

-    def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
+    def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
        bs = x.shape[0]
        if len(self.patch_size) == 3:
            img_ids = self.img_ids(x)
@@ -445,5 +491,5 @@ class HunyuanVideo(nn.Module):
        else:
            img_ids = self.img_ids_2d(x)
            txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
+        out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
        return out
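Every conditioning stream now carries one of three learned type embeddings before entering the transformer. A self-contained illustration of the resulting token layout (token counts and hidden width are made up; the id mapping follows the comment in `__init__`):

    import torch
    import torch.nn as nn

    B, D = 1, 3072  # hypothetical batch and hidden size
    cond_type_embedding = nn.Embedding(3, D)

    txt = torch.randn(B, 256, D)      # text-encoder tokens -> type id 0
    txt_byt5 = torch.randn(B, 64, D)  # byt5 tokens         -> type id 1
    vision = torch.randn(B, 32, D)    # vision tokens       -> type id 2

    txt = txt + cond_type_embedding(torch.zeros(B, 256, dtype=torch.long))
    txt_byt5 = txt_byt5 + cond_type_embedding(torch.ones(B, 64, dtype=torch.long))
    vision = vision + cond_type_embedding(2 * torch.ones(B, 32, dtype=torch.long))

    # HunyuanVideo 1.5 ordering: vision first, then byt5, then text.
    cond = torch.cat((vision, txt_byt5, txt), dim=1)  # (B, 352, D)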

View File

@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
from comfy import model_management, model_patcher

class SRResidualCausalBlock3D(nn.Module):
    def __init__(self, channels: int):
        super().__init__()
        self.block = nn.Sequential(
            VideoConv3d(channels, channels, kernel_size=3),
            nn.SiLU(inplace=True),
            VideoConv3d(channels, channels, kernel_size=3),
            nn.SiLU(inplace=True),
            VideoConv3d(channels, channels, kernel_size=3),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.block(x)

class SRModel3DV2(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: int = 64,
        num_blocks: int = 6,
        global_residual: bool = False,
    ):
        super().__init__()
        self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
        self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
        self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
        self.global_residual = bool(global_residual)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        y = self.in_conv(x)
        for blk in self.blocks:
            y = blk(y)
        y = self.out_conv(y)
        if self.global_residual and (y.shape == residual.shape):
            y = y + residual
        return y

class Upsampler(nn.Module):
    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int = 2,
    ):
        super().__init__()
        self.num_res_blocks = num_res_blocks
        self.block_out_channels = block_out_channels
        self.z_channels = z_channels

        ch = block_out_channels[0]
        self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)

        self.up = nn.ModuleList()
        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                     out_channels=tgt,
                                                     temb_channels=0,
                                                     conv_shortcut=False,
                                                     conv_op=VideoConv3d, norm_op=RMS_norm)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            self.up.append(stage)

        self.norm_out = RMS_norm(ch)
        self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)

    def forward(self, z):
        """
        Args:
            z: (B, C, T, H, W) latent to upsample
        """
        # z to block_in
        repeats = self.block_out_channels[0] // self.z_channels
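        # Wide residual: repeat z's channels so the skip connection matches
        # the hidden width produced by conv_in.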
        x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)

        # upsampling
        for stage in self.up:
            for blk in stage.block:
                x = blk(x)

        out = self.conv_out(F.silu(self.norm_out(x)))
        return out

UPSAMPLERS = {
    "720p": SRModel3DV2,
    "1080p": Upsampler,
}

class HunyuanVideo15SRModel():
    def __init__(self, model_type, config):
        self.load_device = model_management.vae_device()
        offload_device = model_management.vae_offload_device()
        self.dtype = model_management.vae_dtype(self.load_device)
        self.model_class = UPSAMPLERS.get(model_type)
        self.model = self.model_class(**config).eval()
        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

    def load_sd(self, sd):
        return self.model.load_state_dict(sd, strict=True)

    def get_sd(self):
        return self.model.state_dict()

    def resample_latent(self, latent):
        model_management.load_model_gpu(self.patcher)
        return self.model(latent.to(self.load_device))
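A sketch of how the wrapper is presumably driven; the `"720p"`/`"1080p"` keys come from `UPSAMPLERS` above, but the config values and the origin of `state_dict`/`latent` are illustrative:

    # Hypothetical config; real channel counts come from the checkpoint.
    sr = HunyuanVideo15SRModel("720p", {"in_channels": 32, "out_channels": 32})
    sr.load_sd(state_dict)                     # strict load of the upsampler weights
    hires_latent = sr.resample_latent(latent)  # (B, C, T, H, W) -> upsampled latent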

View File

@@ -4,8 +4,40 @@ import torch.nn.functional as F
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
import comfy.ops
import comfy.ldm.models.autoencoder
+import comfy.model_management

ops = comfy.ops.disable_weight_init

+class NoPadConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+        super().__init__()
+        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+    def forward(self, x):
+        return self.conv(x)
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+    if conv_carry_out is not None:
+        to_push = x[:, :, -2:, :, :].clone()
+        conv_carry_out.append(to_push)
+
+    if isinstance(op, NoPadConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode='replicate')
+
+    out = op(x)
+    return out

class RMS_norm(nn.Module):
    def __init__(self, dim):
        super().__init__()
@@ -14,7 +46,7 @@ class RMS_norm(nn.Module):
        self.gamma = nn.Parameter(torch.empty(shape))

    def forward(self, x):
-        return F.normalize(x, dim=1) * self.scale * self.gamma
+        return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)

class DnSmpl(nn.Module):
    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
@@ -27,11 +59,12 @@ class DnSmpl(nn.Module):
        self.tds = tds
        self.gs = fct * ic // oc

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        r1 = 2 if self.tds else 1
-        h = self.conv(x)
+        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

-        if self.tds and self.refiner_vae:
+        if self.tds and self.refiner_vae and conv_carry_in is None:
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
@@ -39,14 +72,7 @@ class DnSmpl(nn.Module):
            hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
            hf = torch.cat([hf, hf], dim=1)

-            hn = h[:, :, 1:, :, :]
-            b, c, frms, ht, wd = hn.shape
-            nf = frms // r1
-            hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
-            hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
-            h = torch.cat([hf, hn], dim=2)
+            h = h[:, :, 1:, :, :]

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
@@ -54,34 +80,32 @@ class DnSmpl(nn.Module):
            xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
            xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
            B, C, T, H, W = xf.shape
-            xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
+            xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)

-            xn = x[:, :, 1:, :, :]
-            b, ci, frms, ht, wd = xn.shape
-            nf = frms // r1
-            xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
-            xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
-            B, C, T, H, W = xn.shape
-            xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
-            sc = torch.cat([xf, xn], dim=2)
-        else:
-            b, c, frms, ht, wd = h.shape
-            nf = frms // r1
-            h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
-            h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+            x = x[:, :, 1:, :, :]

-            b, ci, frms, ht, wd = x.shape
-            nf = frms // r1
-            sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
-            sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
-            sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
-            B, C, T, H, W = sc.shape
-            sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+            if h.shape[2] == 0:
+                return hf + xf

-        return h + sc
+        b, c, frms, ht, wd = h.shape
+        nf = frms // r1
+        h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
+        h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
+        h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
+
+        b, ci, frms, ht, wd = x.shape
+        nf = frms // r1
+        x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
+        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
+        x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
+        B, C, T, H, W = x.shape
+        x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
+
+        if self.tds and self.refiner_vae and conv_carry_in is None:
+            h = torch.cat([hf, h], dim=2)
+            x = torch.cat([xf, x], dim=2)
+
+        return h + x

class UpSmpl(nn.Module):
@@ -94,11 +118,11 @@ class UpSmpl(nn.Module):
        self.tus = tus
        self.rp = fct * oc // ic

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
        r1 = 2 if self.tus else 1
-        h = self.conv(x)
+        h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)

-        if self.tus and self.refiner_vae:
+        if self.tus and self.refiner_vae and conv_carry_in is None:
            hf = h[:, :, :1, :, :]
            b, c, f, ht, wd = hf.shape
            nc = c // (2 * 2)
@@ -107,14 +131,7 @@ class UpSmpl(nn.Module):
            hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
            hf = hf[:, : hf.shape[1] // 2]

-            hn = h[:, :, 1:, :, :]
-            b, c, frms, ht, wd = hn.shape
-            nc = c // (r1 * 2 * 2)
-            hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-            h = torch.cat([hf, hn], dim=2)
+            h = h[:, :, 1:, :, :]

            xf = x[:, :, :1, :, :]
            b, ci, f, ht, wd = xf.shape
@@ -125,29 +142,43 @@ class UpSmpl(nn.Module):
            xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
            xf = xf.reshape(b, nc, f, ht * 2, wd * 2)

-            xn = x[:, :, 1:, :, :]
-            xn = xn.repeat_interleave(repeats=self.rp, dim=1)
-            b, c, frms, ht, wd = xn.shape
-            nc = c // (r1 * 2 * 2)
-            xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
-            sc = torch.cat([xf, xn], dim=2)
-        else:
-            b, c, frms, ht, wd = h.shape
-            nc = c // (r1 * 2 * 2)
-            h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+            x = x[:, :, 1:, :, :]

-            sc = x.repeat_interleave(repeats=self.rp, dim=1)
-            b, c, frms, ht, wd = sc.shape
-            nc = c // (r1 * 2 * 2)
-            sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
-            sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
-            sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+        b, c, frms, ht, wd = h.shape
+        nc = c // (r1 * 2 * 2)
+        h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+        h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
+        h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)

-        return h + sc
+        x = x.repeat_interleave(repeats=self.rp, dim=1)
+        b, c, frms, ht, wd = x.shape
+        nc = c // (r1 * 2 * 2)
+        x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
+        x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
+        x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
+
+        if self.tus and self.refiner_vae and conv_carry_in is None:
+            h = torch.cat([hf, h], dim=2)
+            x = torch.cat([xf, x], dim=2)
+
+        return h + x
+
+class HunyuanRefinerResnetBlock(ResnetBlock):
+    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
+        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
+        h = x
+        h = [self.swish(self.norm1(x))]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+        h = [self.dropout(self.swish(self.norm2(h)))]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
+
+        if self.in_channels != self.out_channels:
+            x = self.nin_shortcut(x)
+
+        return x + h

class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
@@ -160,7 +191,7 @@ class Encoder(nn.Module):
        self.refiner_vae = refiner_vae
        if self.refiner_vae:
-            conv_op = VideoConv3d
+            conv_op = NoPadConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
@@ -175,10 +206,9 @@ class Encoder(nn.Module):
        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
-            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
-                                                     out_channels=tgt,
-                                                     temb_channels=0,
-                                                     conv_op=conv_op, norm_op=norm_op)
+            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+                                                                   out_channels=tgt,
+                                                                   conv_op=conv_op, norm_op=norm_op)
                                         for j in range(num_res_blocks)])
            ch = tgt
            if i < depth:
@@ -188,9 +218,9 @@ class Encoder(nn.Module):
            self.down.append(stage)

        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.norm_out = norm_op(ch)
        self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -201,31 +231,50 @@ class Encoder(nn.Module):
        if not self.refiner_vae and x.shape[2] == 1:
            x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)

-        x = self.conv_in(x)
+        if self.refiner_vae:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > self.ffactor_temporal:
+                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
+            x = xl
+        else:
+            x = [x]

-        for stage in self.down:
-            for blk in stage.block:
-                x = blk(x)
-            if hasattr(stage, 'downsample'):
-                x = stage.downsample(x)
-
-        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
+        out = []
+        conv_carry_in = None
+
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+            x1 = [x1]
+            x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+            for stage in self.down:
+                for blk in stage.block:
+                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                if hasattr(stage, 'downsample'):
+                    x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
+            out.append(x1)
+            conv_carry_in = conv_carry_out
+
+        if len(out) > 1:
+            out = torch.cat(out, dim=2)
+        else:
+            out = out[0]
+
+        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
+        del out

        b, c, t, h, w = x.shape
        grp = c // (self.z_channels << 1)
        skip = x.view(b, c // grp, grp, t, h, w).mean(2)

-        out = self.conv_out(F.silu(self.norm_out(x))) + skip
+        out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip

        if self.refiner_vae:
            out = self.regul(out)[0]
+
+            out = torch.cat((out[:, :, :1], out), dim=2)
+            out = out.permute(0, 2, 1, 3, 4)
+            b, f_times_2, c, h, w = out.shape
+            out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
+            out = out.permute(0, 2, 1, 3, 4).contiguous()

        return out

class Decoder(nn.Module):
@@ -239,7 +288,7 @@ class Decoder(nn.Module):
        self.refiner_vae = refiner_vae
        if self.refiner_vae:
-            conv_op = VideoConv3d
+            conv_op = NoPadConv3d
            norm_op = RMS_norm
        else:
            conv_op = ops.Conv3d
@@ -249,9 +298,9 @@ class Decoder(nn.Module):
        self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

        self.mid = nn.Module()
-        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
        self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

        self.up = nn.ModuleList()
        depth = (ffactor_spatial >> 1).bit_length()
@@ -259,10 +308,9 @@ class Decoder(nn.Module):
        for i, tgt in enumerate(block_out_channels):
            stage = nn.Module()
-            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
-                                                     out_channels=tgt,
-                                                     temb_channels=0,
-                                                     conv_op=conv_op, norm_op=norm_op)
+            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+                                                                   out_channels=tgt,
+                                                                   conv_op=conv_op, norm_op=norm_op)
                                         for j in range(num_res_blocks + 1)])
            ch = tgt
            if i < depth:
@@ -275,27 +323,41 @@ class Decoder(nn.Module):
        self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)

    def forward(self, z):
-        if self.refiner_vae:
-            z = z.permute(0, 2, 1, 3, 4)
-            b, f, c, h, w = z.shape
-            z = z.reshape(b, f, 2, c // 2, h, w)
-            z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
-            z = z.permute(0, 2, 1, 3, 4)
-            z = z[:, :, 1:]
-
-        x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
+        x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)

        x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))

-        for stage in self.up:
-            for blk in stage.block:
-                x = blk(x)
-            if hasattr(stage, 'upsample'):
-                x = stage.upsample(x)
+        if self.refiner_vae:
+            x = torch.split(x, 2, dim=2)
+        else:
+            x = [x]

-        out = self.conv_out(F.silu(self.norm_out(x)))
+        out = []
+        conv_carry_in = None
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
+            for stage in self.up:
+                for blk in stage.block:
+                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                if hasattr(stage, 'upsample'):
+                    x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
+
+            x1 = [F.silu(self.norm_out(x1))]
+            x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
+            out.append(x1)
+            conv_carry_in = conv_carry_out
+
+        del x
+        if len(out) > 1:
+            out = torch.cat(out, dim=2)
+        else:
+            out = out[0]

        if not self.refiner_vae:
            if z.shape[-3] == 1:
                out = out[:, :, -1:]

        return out
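The carry lists make the refiner VAE stream over temporal chunks while staying exactly causal: each chunk pushes its last two input frames into `conv_carry_out`, and the next chunk pops them in place of the replicate padding. A standalone toy check of that invariant (single conv and chunk sizes assumed, not the shipped configuration):

    import torch
    import torch.nn.functional as F

    conv = torch.nn.Conv3d(1, 1, 3)  # shared weights for both runs
    x = torch.randn(1, 1, 8, 8, 8)

    def causal(x, carry=None):
        # First chunk: replicate-pad 2 frames of history; later chunks: use the carry.
        if carry is None:
            x = F.pad(x, (1, 1, 1, 1, 2, 0), mode='replicate')
        else:
            x = F.pad(torch.cat([carry, x], dim=2), (1, 1, 1, 1, 0, 0), mode='replicate')
        return conv(x)

    full = causal(x)
    chunked = torch.cat([causal(x[:, :, :4]), causal(x[:, :, 4:], carry=x[:, :, 2:4])], dim=2)
    assert torch.allclose(full, chunked, atol=1e-5)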

View File

@@ -3,12 +3,11 @@ from torch import nn
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
-from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
+from comfy.ldm.flux.math import apply_rope1

def get_timestep_embedding(
    timesteps: torch.Tensor,
@@ -238,20 +237,6 @@ class FeedForward(nn.Module):
        return self.net(x)

-def apply_rotary_emb(input_tensor, freqs_cis):  #TODO: remove duplicate funcs and pick the best/fastest one
-    cos_freqs = freqs_cis[0]
-    sin_freqs = freqs_cis[1]
-
-    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
-    t1, t2 = t_dup.unbind(dim=-1)
-    t_dup = torch.stack((-t2, t1), dim=-1)
-    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
-
-    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
-
-    return out

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
        super().__init__()
@@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
        k = self.k_norm(k)

        if pe is not None:
-            q = apply_rotary_emb(q, pe)
-            k = apply_rotary_emb(k, pe)
+            q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
+            k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

-        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
+        attn1_input = comfy.ldm.common_dit.rms_norm(x)
+        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
+        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
+        x.addcmul_(attn1_input, gate_msa)
+        del attn1_input
        x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

-        y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
-        x += self.ff(y) * gate_mlp
+        y = comfy.ldm.common_dit.rms_norm(x)
+        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
+        x.addcmul_(self.ff(y), gate_mlp)
        return x

@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):

def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
-    dtype = torch.float32  #self.dtype
+    dtype = torch.float32
+    device = indices_grid.device

+    # Get fractional positions and compute frequency indices
    fractional_positions = get_fractional_positions(indices_grid, max_pos)
+    indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2

-    start = 1
-    end = theta
-    device = fractional_positions.device
-
-    indices = theta ** (
-        torch.linspace(
-            math.log(start, theta),
-            math.log(end, theta),
-            dim // 6,
-            device=device,
-            dtype=dtype,
-        )
-    )
-    indices = indices.to(dtype=dtype)
-
-    indices = indices * math.pi / 2
-
-    freqs = (
-        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
-        .transpose(-1, -2)
-        .flatten(2)
-    )
-
-    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
-    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
+    # Compute frequencies and apply cos/sin
+    freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
+    cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
+    sin_vals = freqs.sin().repeat_interleave(2, dim=-1)

+    # Pad if dim is not divisible by 6
    if dim % 6 != 0:
-        cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
-        sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
-        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
-        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
+        padding_size = dim % 6
+        cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
+        sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)

-    return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
+    # Reshape and extract one value per pair (since repeat_interleave duplicates each value)
+    cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype)  # [B, N, dim//2]
+    sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype)  # [B, N, dim//2]
+
+    # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
+    freqs_cis = torch.stack([
+        torch.stack([cos_vals, -sin_vals], dim=-1),
+        torch.stack([sin_vals, cos_vals], dim=-1)
+    ], dim=-2).unsqueeze(1)  # [B, 1, N, dim//2, 2, 2]
+
+    return freqs_cis

class LTXVModel(torch.nn.Module):
@@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
        x = self.norm_out(x)
        # Modulation
-        x = x * (1 + scale) + shift
+        x = torch.addcmul(x, x, scale).add_(shift)
        x = self.proj_out(x)
        x = self.patchifier.unpatchify(
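The `addcmul` rewrites are micro-optimizations: `torch.addcmul(t, t, scale)` computes `t + t * scale == t * (1 + scale)` without materializing the `(1 + scale)` intermediate, and the in-place `x.addcmul_(out, gate)` replaces `x += out * gate`. A quick sanity check of the identities:

    import torch

    t = torch.randn(2, 16)
    scale, shift, gate, out = (torch.randn(2, 16) for _ in range(4))

    a = t * (1 + scale) + shift
    b = torch.addcmul(t, t, scale).add_(shift)
    assert torch.allclose(a, b, atol=1e-6)

    x1 = t.clone(); x1 += out * gate
    x2 = t.clone(); x2.addcmul_(out, gate)
    assert torch.allclose(x1, x2)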

View File

@ -11,6 +11,7 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from comfy.ldm.modules.attention import optimized_attention_masked from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension import comfy.patcher_extension
@ -31,6 +32,7 @@ class JointAttention(nn.Module):
n_heads: int, n_heads: int,
n_kv_heads: Optional[int], n_kv_heads: Optional[int],
qk_norm: bool, qk_norm: bool,
out_bias: bool = False,
operation_settings={}, operation_settings={},
): ):
""" """
@ -59,7 +61,7 @@ class JointAttention(nn.Module):
self.out = operation_settings.get("operations").Linear( self.out = operation_settings.get("operations").Linear(
n_heads * self.head_dim, n_heads * self.head_dim,
dim, dim,
bias=False, bias=out_bias,
device=operation_settings.get("device"), device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"), dtype=operation_settings.get("dtype"),
) )
@ -70,35 +72,6 @@ class JointAttention(nn.Module):
else: else:
self.q_norm = self.k_norm = nn.Identity() self.q_norm = self.k_norm = nn.Identity()
@staticmethod
def apply_rotary_emb(
x_in: torch.Tensor,
freqs_cis: torch.Tensor,
) -> torch.Tensor:
"""
Apply rotary embeddings to input tensors using the given frequency
tensor.
This function applies rotary embeddings to the given query 'xq' and
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors
contain rotary embeddings and are returned as real tensors.
Args:
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
exponentials.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
and key tensor with rotary embeddings.
"""
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@ -134,8 +107,7 @@ class JointAttention(nn.Module):
xq = self.q_norm(xq) xq = self.q_norm(xq)
xk = self.k_norm(xk) xk = self.k_norm(xk)
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) xq, xk = apply_rope(xq, xk, freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
n_rep = self.n_local_heads // self.n_local_kv_heads n_rep = self.n_local_heads // self.n_local_kv_heads
if n_rep >= 1: if n_rep >= 1:
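For context on the n_rep lines above: a sketch of the grouped-query pattern they implement (one common formulation, assumed here), where each KV head serves n_rep query heads and K/V are repeated along the head dimension before attention:

import torch

def repeat_kv(xk, n_rep):
    # xk: [B, N, n_kv_heads, head_dim] -> [B, N, n_kv_heads * n_rep, head_dim]
    return xk if n_rep == 1 else xk.repeat_interleave(n_rep, dim=2)

xk = torch.randn(2, 16, 8, 64)   # 8 KV heads
print(repeat_kv(xk, 4).shape)    # torch.Size([2, 16, 32, 64])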
@ -215,6 +187,8 @@ class JointTransformerBlock(nn.Module):
norm_eps: float, norm_eps: float,
qk_norm: bool, qk_norm: bool,
modulation=True, modulation=True,
z_image_modulation=False,
attn_out_bias=False,
operation_settings={}, operation_settings={},
) -> None: ) -> None:
""" """
@ -235,10 +209,10 @@ class JointTransformerBlock(nn.Module):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.head_dim = dim // n_heads self.head_dim = dim // n_heads
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
self.feed_forward = FeedForward( self.feed_forward = FeedForward(
dim=dim, dim=dim,
hidden_dim=4 * dim, hidden_dim=dim,
multiple_of=multiple_of, multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier, ffn_dim_multiplier=ffn_dim_multiplier,
operation_settings=operation_settings, operation_settings=operation_settings,
@ -252,16 +226,27 @@ class JointTransformerBlock(nn.Module):
self.modulation = modulation self.modulation = modulation
if modulation: if modulation:
self.adaLN_modulation = nn.Sequential( if z_image_modulation:
nn.SiLU(), self.adaLN_modulation = nn.Sequential(
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
min(dim, 1024), min(dim, 256),
4 * dim, 4 * dim,
bias=True, bias=True,
device=operation_settings.get("device"), device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"), dtype=operation_settings.get("dtype"),
), ),
) )
else:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operation_settings.get("operations").Linear(
min(dim, 1024),
4 * dim,
bias=True,
device=operation_settings.get("device"),
dtype=operation_settings.get("dtype"),
),
)
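Both branches feed the same adaLN mechanism; only the conditioning width (256 vs. 1024) and the presence of the SiLU differ. A minimal sketch with toy sizes (the chunk layout is an assumption for illustration):

import torch
import torch.nn as nn

dim, cond_dim = 64, 256
adaLN = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, 4 * dim, bias=True))

cond = torch.randn(2, cond_dim)                    # [B, cond_dim]
shift1, scale1, shift2, scale2 = adaLN(cond).chunk(4, dim=-1)

x = torch.randn(2, 128, dim)                       # [B, N, dim] tokens
x = x * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)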
def forward( def forward(
self, self,
@ -323,7 +308,7 @@ class FinalLayer(nn.Module):
The final layer of NextDiT. The final layer of NextDiT.
""" """
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
super().__init__() super().__init__()
self.norm_final = operation_settings.get("operations").LayerNorm( self.norm_final = operation_settings.get("operations").LayerNorm(
hidden_size, hidden_size,
@ -340,10 +325,15 @@ class FinalLayer(nn.Module):
dtype=operation_settings.get("dtype"), dtype=operation_settings.get("dtype"),
) )
if z_image_modulation:
min_mod = 256
else:
min_mod = 1024
self.adaLN_modulation = nn.Sequential( self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.SiLU(),
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
min(hidden_size, 1024), min(hidden_size, min_mod),
hidden_size, hidden_size,
bias=True, bias=True,
device=operation_settings.get("device"), device=operation_settings.get("device"),
@ -373,12 +363,16 @@ class NextDiT(nn.Module):
n_heads: int = 32, n_heads: int = 32,
n_kv_heads: Optional[int] = None, n_kv_heads: Optional[int] = None,
multiple_of: int = 256, multiple_of: int = 256,
ffn_dim_multiplier: Optional[float] = None, ffn_dim_multiplier: float = 4.0,
norm_eps: float = 1e-5, norm_eps: float = 1e-5,
qk_norm: bool = False, qk_norm: bool = False,
cap_feat_dim: int = 5120, cap_feat_dim: int = 5120,
axes_dims: List[int] = (16, 56, 56), axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (1, 512, 512), axes_lens: List[int] = (1, 512, 512),
rope_theta=10000.0,
z_image_modulation=False,
time_scale=1.0,
pad_tokens_multiple=None,
image_model=None, image_model=None,
device=None, device=None,
dtype=None, dtype=None,
@ -390,6 +384,8 @@ class NextDiT(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = in_channels self.out_channels = in_channels
self.patch_size = patch_size self.patch_size = patch_size
self.time_scale = time_scale
self.pad_tokens_multiple = pad_tokens_multiple
self.x_embedder = operation_settings.get("operations").Linear( self.x_embedder = operation_settings.get("operations").Linear(
in_features=patch_size * patch_size * in_channels, in_features=patch_size * patch_size * in_channels,
@ -411,6 +407,7 @@ class NextDiT(nn.Module):
norm_eps, norm_eps,
qk_norm, qk_norm,
modulation=True, modulation=True,
z_image_modulation=z_image_modulation,
operation_settings=operation_settings, operation_settings=operation_settings,
) )
for layer_id in range(n_refiner_layers) for layer_id in range(n_refiner_layers)
@ -434,7 +431,7 @@ class NextDiT(nn.Module):
] ]
) )
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
self.cap_embedder = nn.Sequential( self.cap_embedder = nn.Sequential(
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
operation_settings.get("operations").Linear( operation_settings.get("operations").Linear(
@ -457,18 +454,24 @@ class NextDiT(nn.Module):
ffn_dim_multiplier, ffn_dim_multiplier,
norm_eps, norm_eps,
qk_norm, qk_norm,
z_image_modulation=z_image_modulation,
attn_out_bias=False,
operation_settings=operation_settings, operation_settings=operation_settings,
) )
for layer_id in range(n_layers) for layer_id in range(n_layers)
] ]
) )
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
if self.pad_tokens_multiple is not None:
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
assert (dim // n_heads) == sum(axes_dims) assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims self.axes_dims = axes_dims
self.axes_lens = axes_lens self.axes_lens = axes_lens
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
self.dim = dim self.dim = dim
self.n_heads = n_heads self.n_heads = n_heads
@ -503,96 +506,42 @@ class NextDiT(nn.Module):
bsz = len(x) bsz = len(x)
pH = pW = self.patch_size pH = pW = self.patch_size
device = x[0].device device = x[0].device
dtype = x[0].dtype
if cap_mask is not None: if self.pad_tokens_multiple is not None:
l_effective_cap_len = cap_mask.sum(dim=1).tolist() pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
else: cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
l_effective_cap_len = [num_tokens] * bsz
if cap_mask is not None and not torch.is_floating_point(cap_mask): cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
img_sizes = [(img.size(1), img.size(2)) for img in x] B, C, H, W = x.shape
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
max_seq_len = max( H_tokens, W_tokens = H // pH, W // pW
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
) x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
max_cap_len = max(l_effective_cap_len) x_pos_ids[:, :, 1] = torch.arange(H_tokens, dtype=torch.float32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
max_img_len = max(l_effective_img_len) x_pos_ids[:, :, 2] = torch.arange(W_tokens, dtype=torch.float32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) if self.pad_tokens_multiple is not None:
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
for i in range(bsz): freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
H, W = img_sizes[i]
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
# cap_freqs_cis_shape[1] = max_cap_len
cap_freqs_cis_shape[1] = cap_feats.shape[1]
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
img_freqs_cis_shape = list(freqs_cis.shape)
img_freqs_cis_shape[1] = max_img_len
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
# refine context # refine context
for layer in self.context_refiner: for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
# refine image padded_img_mask = None
flat_x = []
for i in range(bsz):
img = x[i]
C, H, W = img.size()
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
flat_x.append(img)
x = flat_x
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
for i in range(bsz):
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
padded_img_embed = self.x_embedder(padded_img_embed)
padded_img_mask = padded_img_mask.unsqueeze(1)
for layer in self.noise_refiner: for layer in self.noise_refiner:
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
if cap_mask is not None:
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
else:
mask = None
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
for i in range(bsz):
cap_len = l_effective_cap_len[i]
img_len = l_effective_img_len[i]
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
padded_full_embed = torch.cat((cap_feats, x), dim=1)
mask = None
img_sizes = [(H, W)] * bsz
l_effective_cap_len = [cap_feats.shape[1]] * bsz
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
@ -615,7 +564,7 @@ class NextDiT(nn.Module):
y: (N,) tensor of text tokens/features y: (N,) tensor of text tokens/features
""" """
t = self.t_embedder(t, dtype=x.dtype) # (N, D) t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
adaln_input = t adaln_input = t
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute

View File

@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
from comfy.ldm.util import get_obj_from_str, instantiate_from_config from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma from comfy.ldm.modules.ema import LitEma
import comfy.ops import comfy.ops
from einops import rearrange
import comfy.model_management
class DiagonalGaussianRegularizer(torch.nn.Module): class DiagonalGaussianRegularizer(torch.nn.Module):
def __init__(self, sample: bool = False): def __init__(self, sample: bool = False):
@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim self.embed_dim = embed_dim
if ddconfig.get("batch_norm_latent", False):
self.bn_eps = 1e-4
self.bn_momentum = 0.1
self.ps = [2, 2]
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
eps=self.bn_eps,
momentum=self.bn_momentum,
affine=False,
track_running_stats=True,
)
self.bn.eval()
else:
self.bn = None
def get_autoencoder_params(self) -> list: def get_autoencoder_params(self) -> list:
params = super().get_autoencoder_params() params = super().get_autoencoder_params()
return params return params
@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
z = torch.cat(z, 0) z = torch.cat(z, 0)
z, reg_log = self.regularization(z) z, reg_log = self.regularization(z)
if self.bn is not None:
z = rearrange(z,
"... c (i pi) (j pj) -> ... (c pi pj) i j",
pi=self.ps[0],
pj=self.ps[1],
)
z = torch.nn.functional.batch_norm(z,
comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
momentum=self.bn_momentum,
eps=self.bn_eps)
if return_reg_log: if return_reg_log:
return z, reg_log return z, reg_log
return z return z
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
if self.bn is not None:
s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
z = z * s + m
z = rearrange(
z,
"... (c pi pj) i j -> ... c (i pi) (j pj)",
pi=self.ps[0],
pj=self.ps[1],
)
if self.max_batch_size is None: if self.max_batch_size is None:
dec = self.post_quant_conv(z) dec = self.post_quant_conv(z)
dec = self.decoder(dec, **decoder_kwargs) dec = self.decoder(dec, **decoder_kwargs)
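Why decode multiplies by sqrt(running_var + eps) and adds running_mean: it exactly inverts the batch_norm applied on the encode side. A toy check (fixed running stats, affine=False, as above):

import torch

eps = 1e-4
mean, var = torch.randn(8), torch.rand(8) + 0.5

z = torch.randn(2, 8, 4, 4)
z_norm = torch.nn.functional.batch_norm(z, mean, var, training=False, eps=eps)
z_back = z_norm * torch.sqrt(var + eps).view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
assert torch.allclose(z_back, z, atol=1e-5)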

View File

@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations. Embeds scalar timesteps into vector representations.
""" """
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
super().__init__() super().__init__()
if output_size is None:
output_size = hidden_size
self.mlp = nn.Sequential( self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
) )
self.frequency_embedding_size = frequency_embedding_size self.frequency_embedding_size = frequency_embedding_size
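A minimal sketch (assumed sizes) of the decoupled width: the MLP keeps its hidden size internally but can now emit a narrower modulation vector when output_size is set, e.g. 256 for the z_image modulation path:

import torch
import torch.nn as nn

def make_t_embedder(hidden_size, freq_size=256, output_size=None):
    output_size = hidden_size if output_size is None else output_size
    return nn.Sequential(
        nn.Linear(freq_size, hidden_size, bias=True),
        nn.SiLU(),
        nn.Linear(hidden_size, output_size, bias=True),
    )

emb = make_t_embedder(hidden_size=3840, output_size=256)
print(emb(torch.randn(2, 256)).shape)   # torch.Size([2, 256])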

View File

@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)

View File

@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.patcher_extension import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1
class GELU(nn.Module): class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None): def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@ -134,33 +135,34 @@ class Attention(nn.Module):
image_rotary_emb: Optional[torch.Tensor] = None, image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={}, transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1]
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) # Project and reshape to BHND format (batch, heads, seq, dim)
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
img_query = self.norm_q(img_query) img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key) img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query) txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key) txt_key = self.norm_added_k(txt_key)
joint_query = torch.cat([txt_query, img_query], dim=1) joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=1) joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=1) joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rotary_emb(joint_query, image_rotary_emb) joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb) joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_query = joint_query.flatten(start_dim=2) joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
joint_key = joint_key.flatten(start_dim=2) attention_mask, transformer_options=transformer_options,
joint_value = joint_value.flatten(start_dim=2) skip_reshape=True)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
txt_attn_output = joint_hidden_states[:, :seq_txt, :] txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :] img_attn_output = joint_hidden_states[:, seq_txt:, :]
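A small sketch (toy sizes) of the BHND layout adopted above: each projection is viewed and transposed once into [batch, heads, seq, head_dim], so the attention call can take skip_reshape=True instead of re-deriving the heads per call:

import torch

B, N, heads, head_dim = 2, 77, 8, 64
hidden = torch.randn(B, N, heads * head_dim)
q = hidden.view(B, N, heads, head_dim).transpose(1, 2).contiguous()
print(q.shape)   # torch.Size([2, 8, 77, 64])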
@ -234,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
img_normed = self.img_norm1(hidden_states) img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) del img_mod1
txt_normed = self.txt_norm1(encoder_hidden_states) txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) del txt_mod1
img_attn_output, txt_attn_output = self.attn( img_attn_output, txt_attn_output = self.attn(
hidden_states=img_modulated, hidden_states=img_modulated,
@ -246,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
image_rotary_emb=image_rotary_emb, image_rotary_emb=image_rotary_emb,
transformer_options=transformer_options, transformer_options=transformer_options,
) )
del img_modulated
del txt_modulated
hidden_states = hidden_states + img_gate1 * img_attn_output hidden_states = hidden_states + img_gate1 * img_attn_output
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
del img_attn_output
del txt_attn_output
del img_gate1
del txt_gate1
img_normed2 = self.img_norm2(hidden_states) img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
txt_normed2 = self.txt_norm2(encoder_hidden_states) txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
return encoder_hidden_states, hidden_states return encoder_hidden_states, hidden_states
@ -413,7 +419,7 @@ class QwenImageTransformer2DModel(nn.Module):
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1) ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) hidden_states = self.img_in(hidden_states)
@ -433,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):
transformer_options["block_index"] = i
if ("double_block", i) in blocks_replace: if ("double_block", i) in blocks_replace:
def block_wrap(args): def block_wrap(args):
out = {} out = {}

View File

@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
# assert e[0].dtype == torch.float32 # assert e[0].dtype == torch.float32
# self-attention # self-attention
x = x.contiguous() # otherwise implicit in LayerNorm
y = self.self_attn( y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options) freqs, transformer_options=transformer_options)
@ -588,7 +589,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return x return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@ -601,10 +602,22 @@ class WanModel(torch.nn.Module):
if steps_w is None: if steps_w is None:
steps_w = w_len steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2) freqs = self.rope_embedder(img_ids).movedim(1, 2)
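A numeric sketch of what rope_options does to one axis (names match the diff; values illustrative): the span (len - 1) is scaled and the start is shifted before the linspace is built:

import torch

def make_axis(start, length, steps, scale=1.0, shift=0.0):
    length = (length - 1.0) * scale + 1.0
    start = start + shift
    return torch.linspace(start, start + (length - 1), steps=steps)

print(make_axis(0, 16, 16))                # 0, 1, ..., 15
print(make_axis(0, 16, 16, scale=0.5))     # compressed: 0, 0.5, ..., 7.5
print(make_axis(0, 16, 16, shift=4.0))     # shifted: 4, 5, ..., 19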
@ -630,7 +643,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs: if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1 t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes): def unpatchify(self, x, grid_sizes):

View File

@ -135,7 +135,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None: if model_config.custom_operations is None:
fp8 = model_config.optimizations.get("fp8", False) fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
else: else:
operations = model_config.custom_operations operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@ -198,8 +198,14 @@ class BaseModel(torch.nn.Module):
extra_conds[o] = extra extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds) t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() if "latent_shapes" in extra_conds:
return self.model_sampling.calculate_denoised(sigma, model_output, x) xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
if len(model_output) > 1 and not torch.is_tensor(model_output):
model_output, _ = utils.pack_latents(model_output)
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
def process_timestep(self, timestep, **kwargs): def process_timestep(self, timestep, **kwargs):
return timestep return timestep
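A hypothetical sketch of the pack/unpack contract implied above (the actual comfy.utils helpers may differ): flatten a list of latents into one tensor plus their shapes, then restore them losslessly:

import math
import torch

def pack_latents(latents):
    shapes = [t.shape for t in latents]
    flat = torch.cat([t.reshape(t.shape[0], -1) for t in latents], dim=1)
    return flat, shapes

def unpack_latents(flat, shapes):
    out, offset = [], 0
    for s in shapes:
        n = math.prod(s[1:])
        out.append(flat[:, offset:offset + n].reshape(s))
        offset += n
    return out

a, b = torch.randn(1, 16, 8, 8), torch.randn(1, 16, 4, 4)
flat, shapes = pack_latents([a, b])
ra, rb = unpack_latents(flat, shapes)
assert torch.equal(ra, a) and torch.equal(rb, b)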
@ -335,6 +341,14 @@ class BaseModel(torch.nn.Module):
if self.model_config.scaled_fp8 is not None: if self.model_config.scaled_fp8 is not None:
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
# Save mixed precision metadata
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
metadata = {
"format_version": "1.0",
"layers": self.model_config.layer_quant_config
}
unet_state_dict["_quantization_metadata"] = metadata
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
if self.model_type == ModelType.V_PREDICTION: if self.model_type == ModelType.V_PREDICTION:
@ -892,12 +906,13 @@ class Flux(BaseModel):
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None: if attention_mask is not None:
shape = kwargs["noise"].shape shape = kwargs["noise"].shape
mask_ref_size = kwargs["attention_mask_img_shape"] mask_ref_size = kwargs.get("attention_mask_img_shape", None)
# the model will pad to the patch size, and then divide if mask_ref_size is not None:
# essentially dividing and rounding up # the model will pad to the patch size, and then divide
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) # essentially dividing and rounding up
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
guidance = kwargs.get("guidance", 3.5) guidance = kwargs.get("guidance", 3.5)
if guidance is not None: if guidance is not None:
@ -919,9 +934,19 @@ class Flux(BaseModel):
out = {} out = {}
ref_latents = kwargs.get("reference_latents", None) ref_latents = kwargs.get("reference_latents", None)
if ref_latents is not None: if ref_latents is not None:
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out return out
class Flux2(Flux):
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
target_text_len = 512
if cross_attn.shape[1] < target_text_len:
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
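The pad tuple above prepends zeros: F.pad's third and fourth entries pad the second-to-last (sequence) dimension, before and after respectively. A toy check:

import torch

cross_attn = torch.randn(1, 3, 8)   # [B, L, D]
target = 5
padded = torch.nn.functional.pad(cross_attn, (0, 0, target - cross_attn.shape[1], 0))
print(padded.shape)                 # torch.Size([1, 5, 8])
print(padded[0, :2].abs().sum())    # tensor(0.) -> zeros added at the front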
class GenmoMochi(BaseModel): class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@ -1097,9 +1122,13 @@ class Lumina2(BaseModel):
if torch.numel(attention_mask) != attention_mask.sum(): if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item())) out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
cross_attn = kwargs.get("cross_attn", None) cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None: if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
if 'num_tokens' not in out:
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
return out return out
class WAN21(BaseModel): class WAN21(BaseModel):
@ -1530,3 +1559,94 @@ class HunyuanImage21Refiner(HunyuanImage21):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(True) out['disable_time_r'] = comfy.conds.CONDConstant(True)
return out return out
class HunyuanVideo15(HunyuanVideo):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
if extra_channels == 0:
return None
image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]
if image is None:
shape_image = list(noise.shape)
shape_image[1] = extra_channels
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
else:
latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
for i in range(0, image.shape[1], latent_dim):
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.zeros_like(noise)[:, :1]
else:
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
if conditioning_byt5small is not None:
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
guidance = kwargs.get("guidance", 6.0)
if guidance is not None:
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
return out
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
image = utils.resize_to_batch_size(image, noise.shape[0])
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
else:
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
return image
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False)
return out

View File

@ -6,6 +6,20 @@ import math
import logging import logging
import torch import torch
def detect_layer_quantization(metadata):
quant_key = "_quantization_metadata"
if metadata is not None and quant_key in metadata:
quant_metadata = metadata.pop(quant_key)
quant_metadata = json.loads(quant_metadata)
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
return quant_metadata["layers"]
else:
raise ValueError("Invalid quantization metadata format")
return None
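A sketch of the metadata round-trip implied by this reader and the save path earlier in the diff (the per-layer schema is an assumption for illustration):

import json

metadata = {
    "format_version": "1.0",
    "layers": {"diffusion_model.blocks.0.attn.qkv": {"dtype": "float8_e4m3fn"}},
}
parsed = json.loads(json.dumps(metadata))
assert "layers" in parsed
print(parsed["format_version"], list(parsed["layers"]))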
def count_blocks(state_dict_keys, prefix_string): def count_blocks(state_dict_keys, prefix_string):
count = 0 count = 0
while True: while True:
@ -172,30 +186,68 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
dit_config["guidance_embed"] = len(guidance_keys) > 0 dit_config["guidance_embed"] = len(guidance_keys) > 0
# HunyuanVideo 1.5
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
dit_config["use_cond_type_embedding"] = True
else:
dit_config["use_cond_type_embedding"] = False
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
else:
dit_config["vision_in_dim"] = None
return dit_config return dit_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {} dit_config = {}
dit_config["image_model"] = "flux" if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
dit_config["axes_dim"] = [32, 32, 32, 32]
dit_config["num_heads"] = 48
dit_config["mlp_ratio"] = 3.0
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
patch_size = 1
else:
dit_config["image_model"] = "flux"
dit_config["axes_dim"] = [16, 56, 56]
dit_config["num_heads"] = 24
dit_config["mlp_ratio"] = 4.0
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
patch_size = 2
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
patch_size = 2 dit_config["hidden_size"] = 3072
dit_config["context_in_dim"] = 4096
dit_config["patch_size"] = patch_size dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix) in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys: if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) w = state_dict[in_key]
dit_config["out_channels"] = 16 dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
dit_config["hidden_size"] = w.shape[0]
txt_in_key = "{}txt_in.weight".format(key_prefix)
if txt_in_key in state_dict_keys:
w = state_dict[txt_in_key]
dit_config["context_in_dim"] = w.shape[1]
dit_config["hidden_size"] = w.shape[0]
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys: if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
dit_config["image_model"] = "chroma" dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64 dit_config["in_channels"] = 64
@ -364,14 +416,31 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["image_model"] = "lumina2" dit_config["image_model"] = "lumina2"
dit_config["patch_size"] = 2 dit_config["patch_size"] = 2
dit_config["in_channels"] = 16 dit_config["in_channels"] = 16
dit_config["dim"] = 2304 w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] dit_config["dim"] = w.shape[0]
dit_config["cap_feat_dim"] = w.shape[1]
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["qk_norm"] = True dit_config["qk_norm"] = True
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512] if dit_config["dim"] == 2304: # Original Lumina 2
dit_config["n_heads"] = 24
dit_config["n_kv_heads"] = 8
dit_config["axes_dims"] = [32, 32, 32]
dit_config["axes_lens"] = [300, 512, 512]
dit_config["rope_theta"] = 10000.0
dit_config["ffn_dim_multiplier"] = 4.0
elif dit_config["dim"] == 3840: # Z image
dit_config["n_heads"] = 30
dit_config["n_kv_heads"] = 30
dit_config["axes_dims"] = [32, 48, 48]
dit_config["axes_lens"] = [1536, 512, 512]
dit_config["rope_theta"] = 256.0
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
dit_config["z_image_modulation"] = True
dit_config["time_scale"] = 1000.0
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
dit_config["pad_tokens_multiple"] = 32
return dit_config return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@ -701,6 +770,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
else: else:
model_config.optimizations["fp8"] = True model_config.optimizations["fp8"] = True
# Detect per-layer quantization (mixed precision)
layer_quant_config = detect_layer_quantization(metadata)
if layer_quant_config:
model_config.layer_quant_config = layer_quant_config
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
return model_config return model_config
def unet_prefix_from_state_dict(state_dict): def unet_prefix_from_state_dict(state_dict):

View File

@ -101,6 +101,7 @@ if args.deterministic:
directml_enabled = False directml_enabled = False
if args.directml is not None: if args.directml is not None:
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
import torch_directml import torch_directml
directml_enabled = True directml_enabled = True
device_index = args.directml device_index = args.directml
@ -515,6 +516,7 @@ class LoadedModel:
if use_more_vram == 0: if use_more_vram == 0:
use_more_vram = 1e32 use_more_vram = 1e32
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
real_model = self.model.model real_model = self.model.model
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
@ -745,7 +747,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
current_free_mem = get_free_memory(torch_dev) + loaded_memory current_free_mem = get_free_memory(torch_dev) + loaded_memory
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) lowvram_model_memory = lowvram_model_memory - loaded_memory
if lowvram_model_memory == 0:
lowvram_model_memory = 0.1
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1 lowvram_model_memory = 0.1
@ -1055,12 +1060,6 @@ def device_supports_non_blocking(device):
return False return False
return True return True
def device_should_use_non_blocking(device):
if not device_supports_non_blocking(device):
return False
return False
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
def force_channels_last(): def force_channels_last():
if args.force_channels_last: if args.force_channels_last:
return True return True
@ -1075,6 +1074,16 @@ if args.async_offload:
NUM_STREAMS = 2 NUM_STREAMS = 2
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS)) logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
def current_stream(device):
if device is None:
return None
if is_device_cuda(device):
return torch.cuda.current_stream()
elif is_device_xpu(device):
return torch.xpu.current_stream()
else:
return None
stream_counters = {} stream_counters = {}
def get_offload_stream(device): def get_offload_stream(device):
stream_counter = stream_counters.get(device, 0) stream_counter = stream_counters.get(device, 0)
@ -1083,21 +1092,17 @@ def get_offload_stream(device):
if device in STREAMS: if device in STREAMS:
ss = STREAMS[device] ss = STREAMS[device]
s = ss[stream_counter] #Sync the oldest stream in the queue with the current
ss[stream_counter].wait_stream(current_stream(device))
stream_counter = (stream_counter + 1) % len(ss) stream_counter = (stream_counter + 1) % len(ss)
if is_device_cuda(device):
ss[stream_counter].wait_stream(torch.cuda.current_stream())
elif is_device_xpu(device):
ss[stream_counter].wait_stream(torch.xpu.current_stream())
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
return s return ss[stream_counter]
elif is_device_cuda(device): elif is_device_cuda(device):
ss = [] ss = []
for k in range(NUM_STREAMS): for k in range(NUM_STREAMS):
ss.append(torch.cuda.Stream(device=device, priority=0)) ss.append(torch.cuda.Stream(device=device, priority=0))
STREAMS[device] = ss STREAMS[device] = ss
s = ss[stream_counter] s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
return s return s
elif is_device_xpu(device): elif is_device_xpu(device):
@ -1106,18 +1111,14 @@ def get_offload_stream(device):
ss.append(torch.xpu.Stream(device=device, priority=0)) ss.append(torch.xpu.Stream(device=device, priority=0))
STREAMS[device] = ss STREAMS[device] = ss
s = ss[stream_counter] s = ss[stream_counter]
stream_counter = (stream_counter + 1) % len(ss)
stream_counters[device] = stream_counter stream_counters[device] = stream_counter
return s return s
return None return None
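A condensed sketch of the refactored ring (CUDA path only, two streams, mirroring the logic above): the oldest stream waits on the current stream before the ring advances, so work queued on the returned stream cannot race ahead of it:

import torch

if torch.cuda.is_available():
    streams = [torch.cuda.Stream() for _ in range(2)]
    counter = 0

    def next_offload_stream():
        global counter
        streams[counter].wait_stream(torch.cuda.current_stream())  # oldest syncs first
        counter = (counter + 1) % len(streams)
        return streams[counter]

    with torch.cuda.stream(next_offload_stream()):
        x = torch.ones(4, device="cuda") * 2   # runs on the offload stream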
def sync_stream(device, stream): def sync_stream(device, stream):
if stream is None: if stream is None or current_stream(device) is None:
return return
if is_device_cuda(device): current_stream(device).wait_stream(stream)
torch.cuda.current_stream().wait_stream(stream)
elif is_device_xpu(device):
torch.xpu.current_stream().wait_stream(stream)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device: if device is None or weight.device == device:
@ -1142,6 +1143,83 @@ def cast_to_device(tensor, device, dtype, copy=False):
non_blocking = device_supports_non_blocking(device) non_blocking = device_supports_non_blocking(device)
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
PINNED_MEMORY = {}
TOTAL_PINNED_MEMORY = 0
MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory:
if is_nvidia() or is_amd():
if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
def pin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
return False
if not is_device_cpu(tensor.device):
return False
if tensor.is_pinned():
#NOTE: CUDA does detect when a tensor is already pinned and would
#error below, but there are proven cases where this also queues an error
#on the GPU asynchronously. So don't trust the CUDA API and guard here
return False
if not tensor.is_contiguous():
return False
size = tensor.numel() * tensor.element_size()
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
return False
ptr = tensor.data_ptr()
if ptr == 0:
return False
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
PINNED_MEMORY[ptr] = size
TOTAL_PINNED_MEMORY += size
return True
return False
def unpin_memory(tensor):
global TOTAL_PINNED_MEMORY
if MAX_PINNED_MEMORY <= 0:
return False
if not is_device_cpu(tensor.device):
return False
ptr = tensor.data_ptr()
size = tensor.numel() * tensor.element_size()
size_stored = PINNED_MEMORY.get(ptr, None)
if size_stored is None:
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
return False
if size != size_stored:
logging.warning("Size of pinned tensor changed")
return False
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True
return False
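A minimal sketch of the page-locking pattern above (CUDA only; the cudart calls mirror this diff, error handling elided): registering the tensor's storage pins it, enabling genuinely asynchronous host-to-device copies:

import torch

t = torch.randn(1024, 1024)                       # CPU, contiguous
if torch.cuda.is_available():
    size = t.numel() * t.element_size()
    if torch.cuda.cudart().cudaHostRegister(t.data_ptr(), size, 1) == 0:
        assert t.is_pinned()
        t_gpu = t.to("cuda", non_blocking=True)   # no pageable staging copy
        torch.cuda.synchronize()
        torch.cuda.cudart().cudaHostUnregister(t.data_ptr())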
 def sage_attention_enabled():
     return args.use_sage_attention

comfy/model_patcher.py
@@ -259,7 +259,7 @@ class LowVramPatch:
     def __call__(self, weight):
         intermediate_dtype = weight.dtype
         if self.convert_func is not None:
-            weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
+            weight = self.convert_func(weight, inplace=False)
         if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
             intermediate_dtype = torch.float32
@@ -358,13 +358,13 @@ class ModelPatcher:
         self.object_patches_backup = {}
         self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
-        self.model_size()
         self.load_device = load_device
         self.offload_device = offload_device
         self.weight_inplace_update = weight_inplace_update
         self.force_cast_weights = False
         self.patches_uuid = uuid.uuid4()
         self.parent = None
+        self.pinned = set()

         self.attachments: dict[str] = {}
         self.additional_models: dict[str, list[ModelPatcher]] = {}
@@ -402,6 +402,9 @@ class ModelPatcher:
             self.size = comfy.model_management.module_size(self.model)
         return self.size

+    def get_ram_usage(self):
+        return self.model_size()
+
     def loaded_size(self):
         return self.model.model_loaded_weight_memory
@@ -409,7 +412,7 @@ class ModelPatcher:
         return self.model.lowvram_patch_counter

     def clone(self):
-        n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
+        n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
         n.patches = {}
         for k in self.patches:
             n.patches[k] = self.patches[k][:]
@@ -421,6 +424,7 @@ class ModelPatcher:
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
         n.parent = self
+        n.pinned = self.pinned
         n.force_cast_weights = self.force_cast_weights
@@ -577,6 +581,19 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")

+    def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
+        rope_options = self.model_options["transformer_options"].get("rope_options", {})
+        rope_options["scale_x"] = scale_x
+        rope_options["scale_y"] = scale_y
+        rope_options["scale_t"] = scale_t
+        rope_options["shift_x"] = shift_x
+        rope_options["shift_y"] = shift_y
+        rope_options["shift_t"] = shift_t
+        self.model_options["transformer_options"]["rope_options"] = rope_options
+
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
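
A hedged sketch of how a node could drive the new rope setter; the scale/shift values are illustrative and downstream models read them back out of transformer_options:

m = model.clone()  # "model" is a ModelPatcher from a loader node
m.set_model_rope_options(
    scale_x=1.5, shift_x=0.0,  # stretch RoPE positions along width
    scale_y=1.5, shift_y=0.0,  # and along height
    scale_t=1.0, shift_t=0.0,  # leave the temporal axis untouched
)
# consumers see m.model_options["transformer_options"]["rope_options"]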
@@ -746,6 +763,21 @@ class ModelPatcher:
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))

+    def pin_weight_to_device(self, key):
+        weight, set_func, convert_func = get_key_weight(self.model, key)
+        if comfy.model_management.pin_memory(weight):
+            self.pinned.add(key)
+
+    def unpin_weight(self, key):
+        if key in self.pinned:
+            weight, set_func, convert_func = get_key_weight(self.model, key)
+            comfy.model_management.unpin_memory(weight)
+            self.pinned.remove(key)
+
+    def unpin_all_weights(self):
+        for key in list(self.pinned):
+            self.unpin_weight(key)
+
     def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
@@ -767,9 +799,11 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
+        lowvram_mem_counter = 0
         loading = self._load_list()

         load_completely = []
+        offloaded = []
         loading.sort(reverse=True)
         for x in loading:
             n = x[1]
@@ -786,6 +820,7 @@ class ModelPatcher:
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
+                    lowvram_mem_counter += module_mem
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue
@@ -811,6 +846,7 @@ class ModelPatcher:
                         patch_counter += 1
                     cast_weight = True
+                offloaded.append((module_mem, n, m, params))
             else:
                 if hasattr(m, "comfy_cast_weights"):
                     wipe_lowvram_weight(m)
@@ -841,7 +877,9 @@ class ModelPatcher:
                     continue
                 for param in params:
-                    self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
+                    key = "{}.{}".format(n, param)
+                    self.unpin_weight(key)
+                    self.patch_weight_to_device(key, device_to=device_to)

                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                 m.comfy_patched_weights = True
@@ -849,11 +887,17 @@ class ModelPatcher:
         for x in load_completely:
             x[2].to(device_to)

+        for x in offloaded:
+            n = x[1]
+            params = x[3]
+            for param in params:
+                self.pin_weight_to_device("{}.{}".format(n, param))
+
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
+            logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
-            logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
+            logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
             self.model.model_lowvram = False
         if full_load:
             self.model.to(device_to)
@@ -890,6 +934,7 @@ class ModelPatcher:
         self.eject_model()
         if unpatch_weights:
             self.unpatch_hooks()
+            self.unpin_all_weights()
             if self.model.model_lowvram:
                 for m in self.model.modules():
                     move_weight_functions(m, device_to)
@@ -931,7 +976,7 @@ class ModelPatcher:
         self.object_patches_backup.clear()

-    def partially_unload(self, device_to, memory_to_free=0):
+    def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
         with self.use_ejected():
             hooks_unpatched = False
             memory_freed = 0
@@ -982,13 +1027,19 @@ class ModelPatcher:
                     module_mem += move_weight_functions(m, device_to)
                     if lowvram_possible:
                         if weight_key in self.patches:
-                            _, set_func, convert_func = get_key_weight(self.model, weight_key)
-                            m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
-                            patch_counter += 1
+                            if force_patch_weights:
+                                self.patch_weight_to_device(weight_key)
+                            else:
+                                _, set_func, convert_func = get_key_weight(self.model, weight_key)
+                                m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
+                                patch_counter += 1
                         if bias_key in self.patches:
-                            _, set_func, convert_func = get_key_weight(self.model, bias_key)
-                            m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
-                            patch_counter += 1
+                            if force_patch_weights:
+                                self.patch_weight_to_device(bias_key)
+                            else:
+                                _, set_func, convert_func = get_key_weight(self.model, bias_key)
+                                m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
+                                patch_counter += 1
                         cast_weight = True

                     if cast_weight:
@@ -998,9 +1049,13 @@ class ModelPatcher:
                     memory_freed += module_mem
                     logging.debug("freed {}".format(n))

+                    for param in params:
+                        self.pin_weight_to_device("{}.{}".format(n, param))
+
             self.model.model_lowvram = True
             self.model.lowvram_patch_counter += patch_counter
             self.model.model_loaded_weight_memory -= memory_freed
+            logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
             return memory_freed
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
@@ -1013,6 +1068,9 @@ class ModelPatcher:
             extra_memory += (used - self.model.model_loaded_weight_memory)

         self.patch_model(load_weights=False)
+        if extra_memory < 0 and not unpatch_weights:
+            self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
+            return 0
         full_load = False
         if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
             self.apply_hooks(self.forced_hooks, force_apply=True)
@@ -1400,5 +1458,6 @@ class ModelPatcher:
         self.clear_cached_hook_weights()

     def __del__(self):
+        self.unpin_all_weights()
         self.detach(unpatch_all=False)

comfy/nested_tensor.py (new file, 91 lines)

@ -0,0 +1,91 @@
import torch
class NestedTensor:
def __init__(self, tensors):
self.tensors = list(tensors)
self.is_nested = True
def _copy(self):
return NestedTensor(self.tensors)
def apply_operation(self, other, operation):
o = self._copy()
if isinstance(other, NestedTensor):
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other.tensors[i])
else:
for i, t in enumerate(o.tensors):
o.tensors[i] = operation(t, other)
return o
def __add__(self, b):
return self.apply_operation(b, lambda x, y: x + y)
def __sub__(self, b):
return self.apply_operation(b, lambda x, y: x - y)
def __mul__(self, b):
return self.apply_operation(b, lambda x, y: x * y)
# def __itruediv__(self, b):
# return self.apply_operation(b, lambda x, y: x / y)
def __truediv__(self, b):
return self.apply_operation(b, lambda x, y: x / y)
def __getitem__(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
def unbind(self):
return self.tensors
def to(self, *args, **kwargs):
o = self._copy()
for i, t in enumerate(o.tensors):
o.tensors[i] = t.to(*args, **kwargs)
return o
def new_ones(self, *args, **kwargs):
return self.tensors[0].new_ones(*args, **kwargs)
def float(self):
return self.to(dtype=torch.float)
def chunk(self, *args, **kwargs):
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
def size(self):
return self.tensors[0].size()
@property
def shape(self):
return self.tensors[0].shape
@property
def ndim(self):
dims = 0
for t in self.tensors:
dims = max(t.ndim, dims)
return dims
@property
def device(self):
return self.tensors[0].device
@property
def dtype(self):
return self.tensors[0].dtype
@property
def layout(self):
return self.tensors[0].layout
def cat_nested(tensors, *args, **kwargs):
cated_tensors = []
for i in range(len(tensors[0].tensors)):
tens = []
for j in range(len(tensors)):
tens.append(tensors[j].tensors[i])
cated_tensors.append(torch.cat(tens, *args, **kwargs))
return NestedTensor(cated_tensors)
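
A quick sketch of the intended semantics (shapes are illustrative): elementwise operators fan out across the contained tensors, properties report from the first tensor, and cat_nested concatenates element-by-element:

import torch
from comfy.nested_tensor import NestedTensor, cat_nested

a = NestedTensor([torch.zeros(1, 4, 8), torch.zeros(1, 4, 16)])  # ragged last dim
b = (a + 1.0) * 2.0            # applied to every contained tensor
print(b.shape)                 # reports the first tensor's shape: (1, 4, 8)
c = cat_nested([b, b], dim=0)  # concatenates tensor-by-tensor along dim 0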

comfy/ops.py

@@ -35,7 +35,7 @@ def scaled_dot_product_attention(q, k, v, *args, **kwargs):

 try:
-    if torch.cuda.is_available():
+    if torch.cuda.is_available() and comfy.model_management.WINDOWS:
         from torch.nn.attention import SDPBackend, sdpa_kernel
         import inspect
         if "set_priority" in inspect.signature(sdpa_kernel).parameters:
@@ -58,7 +58,8 @@ except (ModuleNotFoundError, TypeError):
 NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
 try:
     if comfy.model_management.is_nvidia():
-        if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
+        cudnn_version = torch.backends.cudnn.version()
+        if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
             #TODO: change upper bound version once it's fixed
             NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
             logging.info("working around nvidia conv3d memory bug.")
@@ -70,42 +71,76 @@ cast_to = comfy.model_management.cast_to #TODO: remove once no more references

 def cast_to_input(weight, input, non_blocking=False, copy=True):
     return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

+@torch.compiler.disable()
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
+    # NOTE: offloadable=False is a legacy default; if you are a custom node author reading this, please pass
+    # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
+    # will add async-offload support to your cast and improve performance.
     if input is not None:
         if dtype is None:
-            dtype = input.dtype
+            if isinstance(input, QuantizedTensor):
+                dtype = input._layout_params["orig_dtype"]
+            else:
+                dtype = input.dtype
         if bias_dtype is None:
             bias_dtype = dtype
         if device is None:
             device = input.device

-    offload_stream = comfy.model_management.get_offload_stream(device)
+    if offloadable and (device != s.weight.device or
+                        (s.bias is not None and device != s.bias.device)):
+        offload_stream = comfy.model_management.get_offload_stream(device)
+    else:
+        offload_stream = None
     if offload_stream is not None:
         wf_context = offload_stream
     else:
         wf_context = contextlib.nullcontext()

-    bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
-    if s.bias is not None:
-        has_function = len(s.bias_function) > 0
-        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)

-        if has_function:
+    weight_has_function = len(s.weight_function) > 0
+    bias_has_function = len(s.bias_function) > 0
+
+    weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
+    bias = None
+    if s.bias is not None:
+        bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
+        if bias_has_function:
             with wf_context:
                 for f in s.bias_function:
                     bias = f(bias)

-    has_function = len(s.weight_function) > 0
-    weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
-    if has_function:
+    if weight_has_function or weight.dtype != dtype:
         with wf_context:
+            weight = weight.to(dtype=dtype)
+            if isinstance(weight, QuantizedTensor):
+                weight = weight.dequantize()
             for f in s.weight_function:
                 weight = f(weight)

     comfy.model_management.sync_stream(device, offload_stream)
-    return weight, bias
+    if offloadable:
+        return weight, bias, offload_stream
+    else:
+        # Legacy function signature
+        return weight, bias
+
+def uncast_bias_weight(s, weight, bias, offload_stream):
+    if offload_stream is None:
+        return
+    if weight is not None:
+        device = weight.device
+    else:
+        if bias is None:
+            return
+        device = bias.device
+    offload_stream.wait_stream(comfy.model_management.current_stream(device))
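
For custom node authors, the offloadable pattern the NOTE above asks for looks roughly like this (a minimal sketch; MyLinear is a hypothetical op):

import torch
import comfy.ops

class MyLinear(torch.nn.Linear, comfy.ops.CastWeightBiasOp):
    def forward(self, input):
        # offloadable=True also returns the stream so the cast can overlap compute
        weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
        out = torch.nn.functional.linear(input, weight, bias)
        # hand the casted copies back to the offload stream after last use
        comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
        return out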
 class CastWeightBiasOp:
     comfy_cast_weights = False
@@ -118,8 +153,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -133,8 +170,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -148,8 +187,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -172,8 +213,10 @@ class disable_weight_init:
                 return super()._conv_forward(input, weight, bias, *args, **kwargs)

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return self._conv_forward(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = self._conv_forward(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -187,8 +230,10 @@ class disable_weight_init:
             return None

         def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -203,11 +248,14 @@ class disable_weight_init:

         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
                 bias = None
-            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+                offload_stream = None
+            x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -223,11 +271,15 @@ class disable_weight_init:

         def forward_comfy_cast_weights(self, input):
             if self.weight is not None:
-                weight, bias = cast_bias_weight(self, input)
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             else:
                 weight = None
-            return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
-            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+                bias = None
+                offload_stream = None
+            x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+            # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -246,10 +298,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose2d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose2d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -268,10 +322,12 @@ class disable_weight_init:
                 input, output_size, self.stride, self.padding, self.kernel_size,
                 num_spatial_dims, self.dilation)

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.conv_transpose1d(
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.conv_transpose1d(
                 input, weight, bias, self.stride, self.padding,
                 output_padding, self.groups, self.dilation)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -289,8 +345,11 @@ class disable_weight_init:
                 output_dtype = out_dtype
             if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                 out_dtype = None
-            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
+            x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def forward(self, *args, **kwargs):
             run_every_op()
@@ -344,20 +403,18 @@ class manual_cast(disable_weight_init):

 def fp8_linear(self, input):
+    """
+    Legacy FP8 linear function for backward compatibility.
+    Uses QuantizedTensor subclass for dispatch.
+    """
     dtype = self.weight.dtype
     if dtype not in [torch.float8_e4m3fn]:
         return None

-    tensor_2d = False
-    if len(input.shape) == 2:
-        tensor_2d = True
-        input = input.unsqueeze(1)
-
-    input_shape = input.shape
     input_dtype = input.dtype
-    if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
-        w = w.t()
+    if input.ndim == 3 or input.ndim == 2:
+        w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)

         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -369,23 +426,20 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
             input = torch.clamp(input, min=-448, max=448, out=input)
-            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
+            layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
+            quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
         else:
             scale_input = scale_input.to(input.device)
-            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
+            quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)

-        if bias is not None:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
-        else:
-            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
+        # Wrap weight in QuantizedTensor - this enables unified dispatch
+        # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
+        layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
+        quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
+        o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
+        uncast_bias_weight(self, w, bias, offload_stream)
+        return o

-        if isinstance(o, tuple):
-            o = o[0]
-        if tensor_2d:
-            return o.reshape(input_shape[0], -1)
-        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
     return None
@@ -405,8 +459,10 @@ class fp8_ops(manual_cast):
             except Exception as e:
                 logging.info("Exception during fp8 op: {}".format(e))

-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.linear(input, weight, bias)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+            x = torch.nn.functional.linear(input, weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x
 def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
     logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
@@ -434,19 +490,21 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
             if out is not None:
                 return out

-            weight, bias = cast_bias_weight(self, input)
+            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
             if weight.numel() < input.numel(): #TODO: optimize
-                return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
+                x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
             else:
-                return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+                x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
+            uncast_bias_weight(self, weight, bias, offload_stream)
+            return x

         def convert_weight(self, weight, inplace=False, **kwargs):
             if inplace:
                 weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                 return weight
             else:
-                return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
+                return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32)

         def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
             weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
@@ -478,8 +536,142 @@ if CUBLAS_IS_AVAILABLE:
         def forward(self, *args, **kwargs):
             return super().forward(*args, **kwargs)

+# ==============================================================================
+# Mixed Precision Operations
+# ==============================================================================
+
+from .quant_ops import QuantizedTensor, QUANT_ALGOS
+
+def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
+    class MixedPrecisionOps(manual_cast):
+        _layer_quant_config = layer_quant_config
+        _compute_dtype = compute_dtype
+        _full_precision_mm = full_precision_mm
+
+        class Linear(torch.nn.Module, CastWeightBiasOp):
+            def __init__(
+                self,
+                in_features: int,
+                out_features: int,
+                bias: bool = True,
+                device=None,
+                dtype=None,
+            ) -> None:
+                super().__init__()
+                self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
+                # self.factory_kwargs = {"device": device, "dtype": dtype}
+                self.in_features = in_features
+                self.out_features = out_features
+                if bias:
+                    self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
+                else:
+                    self.register_parameter("bias", None)
+                self.tensor_class = None
+                self._full_precision_mm = MixedPrecisionOps._full_precision_mm
+
+            def reset_parameters(self):
+                return None
+
+            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
+                                      strict, missing_keys, unexpected_keys, error_msgs):
+                device = self.factory_kwargs["device"]
+                layer_name = prefix.rstrip('.')
+
+                weight_key = f"{prefix}weight"
+                weight = state_dict.pop(weight_key, None)
+                if weight is None:
+                    raise ValueError(f"Missing weight for layer {layer_name}")
+
+                manually_loaded_keys = [weight_key]
+
+                if layer_name not in MixedPrecisionOps._layer_quant_config:
+                    self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
+                else:
+                    quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
+                    if quant_format is None:
+                        raise ValueError(f"Unknown quantization format for layer {layer_name}")
+
+                    qconfig = QUANT_ALGOS[quant_format]
+                    self.layout_type = qconfig["comfy_tensor_layout"]
+
+                    weight_scale_key = f"{prefix}weight_scale"
+                    layout_params = {
+                        'scale': state_dict.pop(weight_scale_key, None),
+                        'orig_dtype': MixedPrecisionOps._compute_dtype,
+                        'block_size': qconfig.get("group_size", None),
+                    }
+                    if layout_params['scale'] is not None:
+                        manually_loaded_keys.append(weight_scale_key)
+
+                    self.weight = torch.nn.Parameter(
+                        QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
+                        requires_grad=False
+                    )
+
+                    for param_name in qconfig["parameters"]:
+                        param_key = f"{prefix}{param_name}"
+                        _v = state_dict.pop(param_key, None)
+                        if _v is None:
+                            continue
+                        setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
+                        manually_loaded_keys.append(param_key)
+
+                super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+                for key in manually_loaded_keys:
+                    if key in missing_keys:
+                        missing_keys.remove(key)
+
+            def _forward(self, input, weight, bias):
+                return torch.nn.functional.linear(input, weight, bias)
+
+            def forward_comfy_cast_weights(self, input):
+                weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
+                x = self._forward(input, weight, bias)
+                uncast_bias_weight(self, weight, bias, offload_stream)
+                return x
+
+            def forward(self, input, *args, **kwargs):
+                run_every_op()
+                if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+                    return self.forward_comfy_cast_weights(input, *args, **kwargs)
+
+                if (getattr(self, 'layout_type', None) is not None and
+                        getattr(self, 'input_scale', None) is not None and
+                        not isinstance(input, QuantizedTensor)):
+                    input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
+                return self._forward(input, self.weight, self.bias)

+            def convert_weight(self, weight, inplace=False, **kwargs):
+                if isinstance(weight, QuantizedTensor):
+                    return weight.dequantize()
+                else:
+                    return weight
+
+            def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
+                if getattr(self, 'layout_type', None) is not None:
+                    weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
+                else:
+                    weight = weight.to(self.weight.dtype)
+                if return_weight:
+                    return weight
+                assert inplace_update is False  # TODO: eventually remove the inplace_update stuff
+                self.weight = torch.nn.Parameter(weight, requires_grad=False)
+
+    return MixedPrecisionOps
+
-def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
-    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
+    fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
+    if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
+        logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
+        return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute)
     if scaled_fp8 is not None:
         return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
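
A rough sketch of what feeds this new path; the config shape follows the code above, while the layer name is illustrative:

import torch
import comfy.ops

# per-layer config as consumed by mixed_precision_ops(); keys are module name prefixes
layer_quant_config = {
    "blocks.0.attn.to_q": {"format": "float8_e4m3fn"},
}
ops = comfy.ops.mixed_precision_ops(layer_quant_config, compute_dtype=torch.bfloat16)
layer = ops.Linear(320, 320, bias=True)
# loading a state dict with "blocks.0.attn.to_q.weight" (fp8) plus "...weight_scale"
# wraps that weight in a QuantizedTensor; unlisted layers load at compute_dtype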

comfy/quant_ops.py (new file, 572 lines)

@ -0,0 +1,572 @@
import torch
import logging
from typing import Tuple, Dict
import comfy.float
_LAYOUT_REGISTRY = {}
_GENERIC_UTILS = {}
def register_layout_op(torch_op, layout_type):
"""
Decorator to register a layout-specific operation handler.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
layout_type: Layout class (e.g., TensorCoreFP8Layout)
Example:
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
def fp8_linear(func, args, kwargs):
# FP8-specific linear implementation
...
"""
def decorator(handler_func):
if torch_op not in _LAYOUT_REGISTRY:
_LAYOUT_REGISTRY[torch_op] = {}
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
return handler_func
return decorator
def register_generic_util(torch_op):
"""
Decorator to register a generic utility that works for all layouts.
Args:
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
Example:
@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
# Works for any layout
...
"""
def decorator(handler_func):
_GENERIC_UTILS[torch_op] = handler_func
return handler_func
return decorator
def _get_layout_from_args(args):
for arg in args:
if isinstance(arg, QuantizedTensor):
return arg._layout_type
elif isinstance(arg, (list, tuple)):
for item in arg:
if isinstance(item, QuantizedTensor):
return item._layout_type
return None
def _move_layout_params_to_device(params, device):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.to(device=device)
else:
new_params[k] = v
return new_params
def _copy_layout_params(params):
new_params = {}
for k, v in params.items():
if isinstance(v, torch.Tensor):
new_params[k] = v.clone()
else:
new_params[k] = v
return new_params
def _copy_layout_params_inplace(src, dst, non_blocking=False):
for k, v in src.items():
if isinstance(v, torch.Tensor):
dst[k].copy_(v, non_blocking=non_blocking)
else:
dst[k] = v
class QuantizedLayout:
"""
Base class for quantization layouts.
A layout encapsulates the format-specific logic for quantization/dequantization
and provides a uniform interface for extracting raw tensors needed for computation.
New quantization formats should subclass this and implement the required methods.
"""
@classmethod
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
@staticmethod
def dequantize(qdata, **layout_params) -> torch.Tensor:
raise NotImplementedError("TensorLayout must implement dequantize()")
@classmethod
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
class QuantizedTensor(torch.Tensor):
"""
Universal quantized tensor that works with any layout.
This tensor subclass uses a pluggable layout system to support multiple
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
The layout_type determines format-specific behavior, while common operations
(detach, clone, to) are handled generically.
Attributes:
_qdata: The quantized tensor data
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
"""
@staticmethod
def __new__(cls, qdata, layout_type, layout_params):
"""
Create a quantized tensor.
Args:
qdata: The quantized data tensor
layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters
"""
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata
self._layout_type = layout_type
self._layout_params = layout_params
def __repr__(self):
layout_name = self._layout_type
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
@property
def layout_type(self):
return self._layout_type
def __tensor_flatten__(self):
"""
Tensor flattening protocol for proper device movement.
"""
inner_tensors = ["_qdata"]
ctx = {
"layout_type": self._layout_type,
}
tensor_params = {}
non_tensor_params = {}
for k, v in self._layout_params.items():
if isinstance(v, torch.Tensor):
tensor_params[k] = v
else:
non_tensor_params[k] = v
ctx["tensor_param_keys"] = list(tensor_params.keys())
ctx["non_tensor_params"] = non_tensor_params
for k, v in tensor_params.items():
attr_name = f"_layout_param_{k}"
object.__setattr__(self, attr_name, v)
inner_tensors.append(attr_name)
return inner_tensors, ctx
@staticmethod
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
"""
Tensor unflattening protocol for proper device movement.
Reconstructs the QuantizedTensor after device movement.
"""
layout_type = ctx["layout_type"]
layout_params = dict(ctx["non_tensor_params"])
for key in ctx["tensor_param_keys"]:
attr_name = f"_layout_param_{key}"
layout_params[key] = inner_tensors[attr_name]
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
@classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor:
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
# Step 1: Check generic utilities first (detach, clone, to, etc.)
if func in _GENERIC_UTILS:
return _GENERIC_UTILS[func](func, args, kwargs)
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
layout_type = _get_layout_from_args(args)
if layout_type and func in _LAYOUT_REGISTRY:
handler = _LAYOUT_REGISTRY[func].get(layout_type)
if handler:
return handler(func, args, kwargs)
# Step 3: Fallback to dequantization
if isinstance(args[0] if args else None, QuantizedTensor):
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
return cls._dequant_and_fallback(func, args, kwargs)
@classmethod
def _dequant_and_fallback(cls, func, args, kwargs):
def dequant_arg(arg):
if isinstance(arg, QuantizedTensor):
return arg.dequantize()
elif isinstance(arg, (list, tuple)):
return type(arg)(dequant_arg(a) for a in arg)
return arg
new_args = dequant_arg(args)
new_kwargs = dequant_arg(kwargs)
return func(*new_args, **new_kwargs)
def data_ptr(self):
return self._qdata.data_ptr()
def is_pinned(self):
return self._qdata.is_pinned()
def is_contiguous(self):
return self._qdata.is_contiguous()
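
In practice the round trip looks like this (a small sketch; shapes and dtypes are illustrative):

import torch
from comfy.quant_ops import QuantizedTensor

w = torch.randn(64, 32, dtype=torch.bfloat16)
qw = QuantizedTensor.from_float(w, "TensorCoreFP8Layout", dtype=torch.float8_e4m3fn)
x = torch.randn(4, 32, dtype=torch.bfloat16)
# __torch_dispatch__ intercepts aten.linear; with a plain bf16 input the registered
# handler takes the dequantize-and-fall-back path and returns a bf16 result
y = torch.nn.functional.linear(x, qw)
w_back = qw.dequantize()  # bf16 approximation of the original weight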
# ==============================================================================
# Generic Utilities (Layout-Agnostic Operations)
# ==============================================================================

def _create_transformed_qtensor(qt, transform_fn):
    new_data = transform_fn(qt._qdata)
    new_params = _copy_layout_params(qt._layout_params)
    return QuantizedTensor(new_data, qt._layout_type, new_params)

def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
    if target_dtype is not None and target_dtype != qt.dtype:
        logging.warning(
            f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
            f"but not supported for quantized tensors. Ignoring dtype."
        )
    if target_layout is not None and target_layout != torch.strided:
        logging.warning(
            f"QuantizedTensor: layout change requested to {target_layout}, "
            f"but not supported. Ignoring layout."
        )

    # Handle device transfer
    current_device = qt._qdata.device
    if target_device is not None:
        # Normalize device for comparison
        if isinstance(target_device, str):
            target_device = torch.device(target_device)
        if isinstance(current_device, str):
            current_device = torch.device(current_device)

        if target_device != current_device:
            logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
            new_q_data = qt._qdata.to(device=target_device)
            new_params = _move_layout_params_to_device(qt._layout_params, target_device)
            new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
            logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
            return new_qt

    logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
    return qt

@register_generic_util(torch.ops.aten.detach.default)
def generic_detach(func, args, kwargs):
    """Detach operation - creates a detached copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.detach())
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.clone.default)
def generic_clone(func, args, kwargs):
    """Clone operation - creates a deep copy of the quantized tensor."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _create_transformed_qtensor(qt, lambda x: x.clone())
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten._to_copy.default)
def generic_to_copy(func, args, kwargs):
    """Device/dtype transfer operation - handles .to(device) calls."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            op_name="_to_copy"
        )
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.to.dtype_layout)
def generic_to_dtype_layout(func, args, kwargs):
    """Handle .to(device) calls using the dtype_layout variant."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        return _handle_device_transfer(
            qt,
            target_device=kwargs.get('device', None),
            target_dtype=kwargs.get('dtype', None),
            target_layout=kwargs.get('layout', None),
            op_name="to"
        )
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.copy_.default)
def generic_copy_(func, args, kwargs):
    qt_dest = args[0]
    src = args[1]
    non_blocking = args[2] if len(args) > 2 else False
    if isinstance(qt_dest, QuantizedTensor):
        if isinstance(src, QuantizedTensor):
            # Copy from another quantized tensor
            qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
            qt_dest._layout_type = src._layout_type
            _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
        else:
            # Copy from regular tensor - just copy raw data
            qt_dest._qdata.copy_(src)
        return qt_dest
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
    """Handle .to(dtype) calls - dtype conversion only."""
    src = args[0]
    if isinstance(src, QuantizedTensor):
        # For dtype-only conversion, just change the orig_dtype; no real cast is needed
        target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
        src._layout_params["orig_dtype"] = target_dtype
        return src
    return func(*args, **kwargs)

@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
    return True

@register_generic_util(torch.ops.aten.empty_like.default)
def generic_empty_like(func, args, kwargs):
    """Empty_like operation - creates an empty tensor with the same quantized structure."""
    qt = args[0]
    if isinstance(qt, QuantizedTensor):
        # Create empty tensor with same shape and dtype as the quantized data
        hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
        new_qdata = torch.empty_like(qt._qdata, **kwargs)
        # Handle device transfer for layout params
        target_device = kwargs.get('device', new_qdata.device)
        new_params = _move_layout_params_to_device(qt._layout_params, target_device)
        # Update orig_dtype if dtype is specified
        new_params['orig_dtype'] = hp_dtype
        return QuantizedTensor(new_qdata, qt._layout_type, new_params)
    return func(*args, **kwargs)

# ==============================================================================
# FP8 Layout + Operation Handlers
# ==============================================================================

class TensorCoreFP8Layout(QuantizedLayout):
    """
    Storage format:
    - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
    - scale: Scalar tensor (float32) for dequantization
    - orig_dtype: Original dtype before quantization (for casting back)
    """
    @classmethod
    def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
        orig_dtype = tensor.dtype
        if scale is None:
            scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max

        if not isinstance(scale, torch.Tensor):
            scale = torch.tensor(scale)
        scale = scale.to(device=tensor.device, dtype=torch.float32)

        if inplace_ops:
            tensor *= (1.0 / scale).to(tensor.dtype)
        else:
            tensor = tensor * (1.0 / scale).to(tensor.dtype)

        if stochastic_rounding > 0:
            tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
        else:
            lp_amax = torch.finfo(dtype).max
            torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
            tensor = tensor.to(dtype, memory_format=torch.contiguous_format)

        layout_params = {
            'scale': scale,
            'orig_dtype': orig_dtype
        }
        return tensor, layout_params

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
        return plain_tensor * scale

    @classmethod
    def get_plain_tensors(cls, qtensor):
        return qtensor._qdata, qtensor._layout_params['scale']

QUANT_ALGOS = {
    "float8_e4m3fn": {
        "storage_t": torch.float8_e4m3fn,
        "parameters": {"weight_scale", "input_scale"},
        "comfy_tensor_layout": "TensorCoreFP8Layout",
    },
}

LAYOUTS = {
    "TensorCoreFP8Layout": TensorCoreFP8Layout,
}
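
Extending the registry with another format would look roughly like the following; Int8Layout here is a hypothetical example, not part of the commit:

import torch
from comfy.quant_ops import QuantizedLayout, register_layout_op, LAYOUTS

class Int8Layout(QuantizedLayout):  # hypothetical per-tensor int8 layout
    @classmethod
    def quantize(cls, tensor, **kwargs):
        scale = torch.amax(tensor.abs()) / 127.0
        q = torch.clamp(torch.round(tensor / scale), -128, 127).to(torch.int8)
        return q, {'scale': scale, 'orig_dtype': tensor.dtype}

    @staticmethod
    def dequantize(qdata, scale, orig_dtype, **kwargs):
        return qdata.to(orig_dtype) * scale

LAYOUTS["Int8Layout"] = Int8Layout

@register_layout_op(torch.ops.aten.linear.default, "Int8Layout")
def int8_linear(func, args, kwargs):
    # placeholder handler: dequantize and run; a real one would call an int8 kernel
    x, w = args[0], args[1].dequantize()
    bias = args[2] if len(args) > 2 else None
    return torch.nn.functional.linear(x, w, bias)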
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    bias = args[2] if len(args) > 2 else None

    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
        out_dtype = kwargs.get("out_dtype")
        if out_dtype is None:
            out_dtype = input_tensor._layout_params['orig_dtype']

        weight_t = plain_weight.t()

        tensor_2d = False
        if len(plain_input.shape) == 2:
            tensor_2d = True
            plain_input = plain_input.unsqueeze(1)

        input_shape = plain_input.shape
        if len(input_shape) != 3:
            return None

        try:
            output = torch._scaled_mm(
                plain_input.reshape(-1, input_shape[2]).contiguous(),
                weight_t,
                bias=bias,
                scale_a=scale_a,
                scale_b=scale_b,
                out_dtype=out_dtype,
            )
            if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
                output = output[0]

            if not tensor_2d:
                output = output.reshape((-1, input_shape[1], weight.shape[0]))
            if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
                output_scale = scale_a * scale_b
                output_params = {
                    'scale': output_scale,
                    'orig_dtype': input_tensor._layout_params['orig_dtype']
                }
                return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
            else:
                return output
        except Exception as e:
            raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")

    # Case 2: DQ Fallback
    if isinstance(weight, QuantizedTensor):
        weight = weight.dequantize()
    if isinstance(input_tensor, QuantizedTensor):
        input_tensor = input_tensor.dequantize()
    return torch.nn.functional.linear(input_tensor, weight, bias)

def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
    if out_dtype is None:
        out_dtype = input_tensor._layout_params['orig_dtype']
    plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
    plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
    output = torch._scaled_mm(
        plain_input.contiguous(),
        plain_weight,
        bias=bias,
        scale_a=scale_a,
        scale_b=scale_b,
        out_dtype=out_dtype,
    )
    if isinstance(output, tuple):  # TODO: remove when we drop support for torch 2.4
        output = output[0]
    return output

@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
def fp8_addmm(func, args, kwargs):
    input_tensor = args[1]
    weight = args[2]
    bias = args[0]
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    if isinstance(args[2], QuantizedTensor):
        a[2] = args[2].dequantize()
    return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
def fp8_mm(func, args, kwargs):
    input_tensor = args[0]
    weight = args[1]
    if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
        return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))

    a = list(args)
    if isinstance(args[0], QuantizedTensor):
        a[0] = args[0].dequantize()
    if isinstance(args[1], QuantizedTensor):
        a[1] = args[1].dequantize()
    return func(*a, **kwargs)

@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
    input_tensor = args[0]
    if isinstance(input_tensor, QuantizedTensor):
        plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
        ar = list(args)
        ar[0] = plain_input
        return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
    return func(*args, **kwargs)

View File

@@ -4,13 +4,9 @@ import comfy.samplers
 import comfy.utils
 import numpy as np
 import logging
+import comfy.nested_tensor

-def prepare_noise(latent_image, seed, noise_inds=None):
-    """
-    creates random noise given a latent image and a seed.
-    optional arg skip can be used to skip and discard x number of noise generations for a given seed
-    """
-    generator = torch.manual_seed(seed)
+def prepare_noise_inner(latent_image, generator, noise_inds=None):
     if noise_inds is None:
         return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
@@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
         if i in unique_inds:
             noises.append(noise)
     noises = [noises[i] for i in inverse]
-    noises = torch.cat(noises, axis=0)
+    return torch.cat(noises, axis=0)
+
+def prepare_noise(latent_image, seed, noise_inds=None):
+    """
+    creates random noise given a latent image and a seed.
+    optional arg skip can be used to skip and discard x number of noise generations for a given seed
+    """
+    generator = torch.manual_seed(seed)
+    if latent_image.is_nested:
+        tensors = latent_image.unbind()
+        noises = []
+        for t in tensors:
+            noises.append(prepare_noise_inner(t, generator, noise_inds))
+        noises = comfy.nested_tensor.NestedTensor(noises)
+    else:
+        noises = prepare_noise_inner(latent_image, generator, noise_inds)
     return noises

 def fix_empty_latent_channels(model, latent_image):
+    if latent_image.is_nested:
+        return latent_image
     latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
     if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
         latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
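
With nested latents the same generator is threaded through every sub-latent, so a given seed stays reproducible across the ragged batch; roughly:

import torch
import comfy.sample
from comfy.nested_tensor import NestedTensor

latents = NestedTensor([torch.zeros(1, 4, 64, 64), torch.zeros(1, 4, 32, 32)])
noise = comfy.sample.prepare_noise(latents, seed=42)  # NestedTensor of two noise tensors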

comfy/samplers.py

@@ -782,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
     return KSAMPLER(sampler_function, extra_options, inpaint_options)

-def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
+def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
     for k in conds:
         conds[k] = conds[k][:]
         resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
@@ -792,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
     if hasattr(model, 'extra_conds'):
         for k in conds:
-            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
+            conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)

     #make sure each cond area has an opposite one with the same area
     for k in conds:
@@ -962,11 +962,11 @@ class CFGGuider:
     def predict_noise(self, x, timestep, model_options={}, seed=None):
         return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)

-    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
+    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
         if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
             latent_image = self.inner_model.process_latent_in(latent_image)

-        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
+        self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)

         extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
         extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
@@ -980,7 +980,7 @@ class CFGGuider:
         samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
         return self.inner_model.process_latent_out(samples.to(torch.float32))

-    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
+    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
         self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device
@ -994,7 +994,7 @@ class CFGGuider:
try: try:
self.model_patcher.pre_run() self.model_patcher.pre_run()
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally: finally:
self.model_patcher.cleanup() self.model_patcher.cleanup()
@ -1007,6 +1007,12 @@ class CFGGuider:
if sigmas.shape[-1] == 0: if sigmas.shape[-1] == 0:
return latent_image return latent_image
if latent_image.is_nested:
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
noise, _ = comfy.utils.pack_latents(noise.unbind())
else:
latent_shapes = [latent_image.shape]
self.conds = {} self.conds = {}
for k in self.original_conds: for k in self.original_conds:
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
@ -1026,7 +1032,7 @@ class CFGGuider:
self, self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
) )
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
finally: finally:
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device) cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
self.model_options = orig_model_options self.model_options = orig_model_options
@ -1034,6 +1040,9 @@ class CFGGuider:
self.model_patcher.restore_hook_patches() self.model_patcher.restore_hook_patches()
del self.conds del self.conds
if len(latent_shapes) > 1:
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
return output return output

View File

@@ -52,6 +52,7 @@ import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.z_image

 import comfy.model_patcher
 import comfy.lora
@@ -143,6 +144,9 @@ class CLIP:
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n

+    def get_ram_usage(self):
+        return self.patcher.get_ram_usage()
+
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)
@@ -293,6 +297,7 @@ class VAE:
         self.working_dtypes = [torch.bfloat16, torch.float32]
         self.disable_offload = False
         self.not_video = False
+        self.size = None

         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -352,7 +357,7 @@ class VAE:
             self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
             self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
-        elif sd['decoder.conv_in.weight'].shape[1] == 32:
+        elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
             self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -378,6 +383,17 @@ class VAE:
                     self.upscale_ratio = 4
                 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
+                if 'decoder.post_quant_conv.weight' in sd:
+                    sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
+
+                if 'bn.running_mean' in sd:
+                    ddconfig["batch_norm_latent"] = True
+                    self.downscale_ratio *= 2
+                    self.upscale_ratio *= 2
+                    self.latent_channels *= 4
+                    old_memory_used_decode = self.memory_used_decode
+                    self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
+
                 if 'post_quant_conv.weight' in sd:
                     self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
                 else:
@@ -437,20 +453,20 @@ class VAE:
         elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
             ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
             ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
-            self.latent_channels = 64
+            self.latent_channels = 32
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
             self.upscale_index_formula = (4, 16, 16)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
             self.downscale_index_formula = (4, 16, 16)
             self.latent_dim = 3
-            self.not_video = True
+            self.not_video = False
             self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
             self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
                                                         encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
                                                         decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
-            self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
-            self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
         elif "decoder.conv_in.conv.weight" in sd:
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
             ddconfig["conv3d"] = True
@@ -595,6 +611,16 @@ class VAE:
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
         logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        self.model_size()
+
+    def model_size(self):
+        if self.size is not None:
+            return self.size
+        self.size = comfy.model_management.module_size(self.first_stage_model)
+        return self.size
+
+    def get_ram_usage(self):
+        return self.model_size()

     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
@@ -897,12 +923,18 @@ class CLIPType(Enum):
     OMNIGEN2 = 17
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
+    HUNYUAN_VIDEO_15 = 20


 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
     for p in ckpt_paths:
-        clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
+        sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
+        if metadata is not None:
+            quant_metadata = metadata.get("_quantization_metadata", None)
+            if quant_metadata is not None:
+                sd["_quantization_metadata"] = quant_metadata
+        clip_data.append(sd)
     return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
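
load_clip now reads the safetensors header and carries any "_quantization_metadata" entry along in the state dict. For reference, the same header can be inspected directly with the safetensors library (file name illustrative):

    from safetensors import safe_open

    # The safetensors header metadata is a dict of str -> str, or None if absent.
    with safe_open("text_encoder.safetensors", framework="pt", device="cpu") as f:
        meta = f.metadata()
        quant = meta.get("_quantization_metadata") if meta is not None else None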
@ -920,6 +952,10 @@ class TEModel(Enum):
QWEN25_7B = 11 QWEN25_7B = 11
BYT5_SMALL_GLYPH = 12 BYT5_SMALL_GLYPH = 12
GEMMA_3_4B = 13 GEMMA_3_4B = 13
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
def detect_te_model(sd): def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd: if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -952,6 +988,15 @@ def detect_te_model(sd):
if weight.shape[0] == 512: if weight.shape[0] == 512:
return TEModel.QWEN25_7B return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd: if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
else:
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
return TEModel.LLAMA3_8 return TEModel.LLAMA3_8
return None return None
@ -1066,6 +1111,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
else: else:
# clip_l # clip_l
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
@ -1112,6 +1164,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_IMAGE: elif clip_type == CLIPType.HUNYUAN_IMAGE:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
else: else:
clip_target.clip = sdxl_clip.SDXLClipModel clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@ -1124,6 +1179,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
parameters = 0 parameters = 0
for c in clip_data: for c in clip_data:
if "_quantization_metadata" in c:
c.pop("_quantization_metadata")
parameters += comfy.utils.calculate_parameters(c) parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@ -1262,7 +1319,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
""" """
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats. Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
@ -1296,7 +1353,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd) weight_dtype = comfy.utils.weight_dtype(sd)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "") model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
if model_config is not None: if model_config is not None:
new_sd = sd new_sd = sd
@ -1331,7 +1388,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
else: else:
unet_dtype = dtype unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) if model_config.layer_quant_config is not None:
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
if model_options.get("fp8_optimizations", False): if model_options.get("fp8_optimizations", False):
@ -1347,8 +1407,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
def load_diffusion_model(unet_path, model_options={}): def load_diffusion_model(unet_path, model_options={}):
sd = comfy.utils.load_torch_file(unet_path) sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
model = load_diffusion_model_state_dict(sd, model_options=model_options) model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))

View File

@@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
                  special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
                  return_projected_pooled=True, return_attention_masks=False, model_options={}):  # clip-vit-base-patch32
         super().__init__()
-        assert layer in self.LAYERS

         if textmodel_json_config is None:
             textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
@@ -109,13 +108,23 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         operations = model_options.get("custom_operations", None)
         scaled_fp8 = None
+        quantization_metadata = model_options.get("quantization_metadata", None)

         if operations is None:
-            scaled_fp8 = model_options.get("scaled_fp8", None)
-            if scaled_fp8 is not None:
-                operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+            layer_quant_config = None
+            if quantization_metadata is not None:
+                layer_quant_config = json.loads(quantization_metadata).get("layers", None)
+            if layer_quant_config is not None:
+                operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True)
+                logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
             else:
-                operations = comfy.ops.manual_cast
+                # Fallback to scaled_fp8_ops for backward compatibility
+                scaled_fp8 = model_options.get("scaled_fp8", None)
+                if scaled_fp8 is not None:
+                    operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
+                else:
+                    operations = comfy.ops.manual_cast

         self.operations = operations
         self.transformer = model_class(config, dtype, device, self.operations)
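
The metadata string is parsed with json.loads and its "layers" mapping handed to comfy.ops.mixed_precision_ops. Only the top-level "layers" key is visible in this diff; a hypothetical example of the shape the code implies, with made-up per-layer fields:

    import json

    # Hypothetical metadata payload; the per-layer fields are guesses, only
    # the "layers" key is implied by the code above.
    quantization_metadata = json.dumps({
        "layers": {
            "transformer.text_model.encoder.layers.0.mlp.fc1": {"format": "float8_e4m3fn"},
            "transformer.text_model.encoder.layers.0.mlp.fc2": {"format": "float8_e4m3fn"},
        }
    })
    layer_quant_config = json.loads(quantization_metadata).get("layers", None)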
@@ -154,7 +163,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
     def set_clip_options(self, options):
         layer_idx = options.get("layer", self.layer_idx)
         self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
-        if self.layer == "all":
+        if isinstance(self.layer, list) or self.layer == "all":
             pass
         elif layer_idx is None or abs(layer_idx) > self.num_layers:
             self.layer = "last"
@@ -256,7 +265,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if self.enable_attention_masks:
             attention_mask_model = attention_mask

-        if self.layer == "all":
+        if isinstance(self.layer, list):
+            intermediate_output = self.layer
+        elif self.layer == "all":
             intermediate_output = "all"
         else:
             intermediate_output = self.layer_idx
@@ -460,7 +471,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
     return embed_out

 class SDTokenizer:
-    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
+    def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
         if tokenizer_path is None:
             tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
         self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
@@ -468,6 +479,7 @@ class SDTokenizer:
         self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
         self.end_token = None
         self.min_padding = min_padding
+        self.pad_left = pad_left

         empty = self.tokenizer('')["input_ids"]
         self.tokenizer_adds_end_token = has_end_token
@@ -522,6 +534,12 @@ class SDTokenizer:
                 return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
         return (embed, leftover)

+    def pad_tokens(self, tokens, amount):
+        if self.pad_left:
+            for i in range(amount):
+                tokens.insert(0, (self.pad_token, 1.0, 0))
+        else:
+            tokens.extend([(self.pad_token, 1.0, 0)] * amount)

     def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
         '''
@@ -600,7 +618,7 @@ class SDTokenizer:
                     if self.end_token is not None:
                         batch.append((self.end_token, 1.0, 0))
                     if self.pad_to_max_length:
-                        batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
+                        self.pad_tokens(batch, remaining_length)
                     #start new batch
                     batch = []
                     if self.start_token is not None:
@@ -614,11 +632,11 @@ class SDTokenizer:
         if self.end_token is not None:
             batch.append((self.end_token, 1.0, 0))
         if min_padding is not None:
-            batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+            self.pad_tokens(batch, min_padding)
         if self.pad_to_max_length and len(batch) < self.max_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
+            self.pad_tokens(batch, self.max_length - len(batch))
         if min_length is not None and len(batch) < min_length:
-            batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
+            self.pad_tokens(batch, min_length - len(batch))

         if not return_word_ids:
             batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
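
Left padding matters for decoder-only models such as the Mistral text encoder added in this PR: the real tokens should occupy the final positions of the sequence. A minimal illustration of what pad_tokens does with (token, weight, word_id) triples:

    # pad_left=False appends pads; pad_left=True (Mistral-style) prepends them.
    pad_token = 11
    batch = [(1, 1.0, 0), (5, 1.0, 1)]

    right = batch + [(pad_token, 1.0, 0)] * 2
    left = [(pad_token, 1.0, 0)] * 2 + batch
    assert left[-1] == (5, 1.0, 1)  # real tokens end the sequence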

View File

@@ -21,6 +21,7 @@ import comfy.text_encoders.ace
 import comfy.text_encoders.omnigen2
 import comfy.text_encoders.qwen_image
 import comfy.text_encoders.hunyuan_image
+import comfy.text_encoders.z_image

 from . import supported_models_base
 from . import latent_formats
@@ -741,6 +742,37 @@ class FluxSchnell(Flux):
         out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
         return out

+class Flux2(Flux):
+    unet_config = {
+        "image_model": "flux2",
+    }
+
+    sampling_settings = {
+        "shift": 2.02,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.Flux2
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Flux2(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        return None # TODO
+        pref = self.text_encoder_key_prefix[0]
+        t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
+
 class GenmoMochi(supported_models_base.BASE):
     unet_config = {
         "image_model": "mochi_preview",
@@ -963,7 +995,7 @@ class Lumina2(supported_models_base.BASE):
         "shift": 6.0,
     }

-    memory_usage_factor = 1.2
+    memory_usage_factor = 1.4

     unet_extra_config = {}
     latent_format = latent_formats.Flux
@@ -982,6 +1014,24 @@ class Lumina2(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))

+class ZImage(Lumina2):
+    unet_config = {
+        "image_model": "lumina2",
+        "dim": 3840,
+    }
+
+    sampling_settings = {
+        "multiplier": 1.0,
+        "shift": 3.0,
+    }
+
+    memory_usage_factor = 1.7
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
+
 class WAN21_T2V(supported_models_base.BASE):
     unet_config = {
         "image_model": "wan2.1",
@@ -1374,6 +1424,55 @@ class HunyuanImage21Refiner(HunyuanVideo):
         out = model_base.HunyuanImage21Refiner(self, device=device)
         return out

+class HunyuanVideo15(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "vision_in_dim": 1152,
+    }
+
+    sampling_settings = {
+        "shift": 7.0,
+    }
+
+    memory_usage_factor = 4.0 #TODO
+
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    latent_format = latent_formats.HunyuanVideo15
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HunyuanVideo15(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
+class HunyuanVideo15_SR_Distilled(HunyuanVideo):
+    unet_config = {
+        "image_model": "hunyuan_video",
+        "vision_in_dim": 1152,
+        "in_channels": 98,
+    }
+
+    sampling_settings = {
+        "shift": 2.0,
+    }
+
+    memory_usage_factor = 4.0 #TODO
+
+    supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+    latent_format = latent_formats.HunyuanVideo15
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
+
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2]

 models += [SVD_img2vid]

View File

@@ -50,6 +50,7 @@ class BASE:
     manual_cast_dtype = None
     custom_operations = None
     scaled_fp8 = None
+    layer_quant_config = None  # Per-layer quantization configuration for mixed precision
     optimizations = {"fp8": False}

     @classmethod

View File

@@ -1,10 +1,13 @@
 from comfy import sd1_clip
 import comfy.text_encoders.t5
 import comfy.text_encoders.sd3_clip
+import comfy.text_encoders.llama
 import comfy.model_management
-from transformers import T5TokenizerFast
+from transformers import T5TokenizerFast, LlamaTokenizerFast
 import torch
 import os
+import json
+import base64

 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -68,3 +71,106 @@ def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
                 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
             super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
     return FluxClipModel_
+
+def load_mistral_tokenizer(data):
+    if torch.is_tensor(data):
+        data = data.numpy().tobytes()
+
+    try:
+        from transformers.integrations.mistral import MistralConverter
+    except ModuleNotFoundError:
+        from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
+
+    mistral_vocab = json.loads(data)
+
+    special_tokens = {}
+    vocab = {}
+    max_vocab = mistral_vocab["config"]["default_vocab_size"]
+    max_vocab -= len(mistral_vocab["special_tokens"])
+
+    for w in mistral_vocab["vocab"]:
+        r = w["rank"]
+        if r >= max_vocab:
+            continue
+        vocab[base64.b64decode(w["token_bytes"])] = r
+
+    for w in mistral_vocab["special_tokens"]:
+        if "token_bytes" in w:
+            special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
+        else:
+            special_tokens[w["token_str"]] = w["rank"]
+
+    all_special = []
+    for v in special_tokens:
+        all_special.append(v)
+
+    special_tokens.update(vocab)
+    vocab = special_tokens
+
+    return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
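
The tekken file stores the vocabulary as JSON with base64-encoded token bytes plus a list of special tokens, which is why the function decodes token_bytes before building the rank table. A toy payload showing just that decode step (values invented for illustration):

    import base64, json

    payload = json.dumps({
        "config": {"default_vocab_size": 4},
        "vocab": [{"rank": 0, "token_bytes": base64.b64encode(b"he").decode()},
                  {"rank": 1, "token_bytes": base64.b64encode(b"llo").decode()}],
        "special_tokens": [{"rank": 2, "token_str": "<s>"}],
    })
    data = json.loads(payload)
    vocab = {base64.b64decode(w["token_bytes"]): w["rank"] for w in data["vocab"]}
    assert vocab == {b"he": 0, b"llo": 1}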
+class MistralTokenizerClass:
+    @staticmethod
+    def from_pretrained(path, **kwargs):
+        return LlamaTokenizerFast(**kwargs)
+
+class Mistral3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        self.tekken_data = tokenizer_data.get("tekken_model", None)
+        super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
+
+    def state_dict(self):
+        return {"tekken_model": self.tekken_data}
+
+class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
+        self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+class Mistral3_24BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+        textmodel_json_config = {}
+        num_layers = model_options.get("num_layers", None)
+        if num_layers is not None:
+            textmodel_json_config["num_hidden_layers"] = num_layers
+            if num_layers < 40:
+                textmodel_json_config["final_norm"] = False
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class Flux2TEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
+        super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
+        out = out.movedim(1, 2)
+        out = out.reshape(out.shape[0], out.shape[1], -1)
+        return out, pooled, extra
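
encode_token_weights concatenates the three intermediate hidden states along the feature dimension, turning (batch, 3, tokens, hidden) into (batch, tokens, 3 * hidden). A shape walkthrough with toy sizes:

    import torch

    batch, layers, tokens, hidden = 2, 3, 7, 5120
    out = torch.randn(batch, layers, tokens, hidden)

    out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)  # (2, 3, 7, 5120)
    out = out.movedim(1, 2)                                      # (2, 7, 3, 5120)
    out = out.reshape(out.shape[0], out.shape[1], -1)            # (2, 7, 15360)
    assert out.shape == (batch, tokens, layers * hidden)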
+def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False):
+    class Flux2TEModel_(Flux2TEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            if pruned:
+                model_options = model_options.copy()
+                model_options["num_layers"] = 30
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return Flux2TEModel_

View File

@@ -1,6 +1,7 @@
 from comfy import sd1_clip
 import comfy.model_management
 import comfy.text_encoders.llama
+from .hunyuan_image import HunyuanImageTokenizer
 from transformers import LlamaTokenizerFast
 import torch
 import os
@@ -17,6 +18,9 @@ def llama_detect(state_dict, prefix=""):
     if scaled_fp8_key in state_dict:
         out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

+    if "_quantization_metadata" in state_dict:
+        out["llama_quantization_metadata"] = state_dict["_quantization_metadata"]
+
     return out
@@ -73,6 +77,14 @@ class HunyuanVideoTokenizer:
         return {}

+class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+
+    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+        return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
+
 class HunyuanVideoClipModel(torch.nn.Module):
     def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
         super().__init__()

View File

@@ -32,6 +32,29 @@ class Llama2Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
+
+@dataclass
+class Mistral3Small24BConfig:
+    vocab_size: int = 131072
+    hidden_size: int = 5120
+    intermediate_size: int = 32768
+    num_hidden_layers: int = 40
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 8192
+    rms_norm_eps: float = 1e-5
+    rope_theta: float = 1000000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = None
+    k_norm = None
+    rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Qwen25_3BConfig:
@@ -53,6 +76,29 @@ class Qwen25_3BConfig:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True
+
+@dataclass
+class Qwen3_4BConfig:
+    vocab_size: int = 151936
+    hidden_size: int = 2560
+    intermediate_size: int = 9728
+    num_hidden_layers: int = 36
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 8
+    max_position_embeddings: int = 40960
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    transformer_type: str = "llama"
+    head_dim = 128
+    rms_norm_add = False
+    mlp_activation = "silu"
+    qkv_bias = False
+    rope_dims = None
+    q_norm = "gemma3"
+    k_norm = "gemma3"
+    rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Qwen25_7BVLI_Config:
@@ -74,6 +120,7 @@ class Qwen25_7BVLI_Config:
     q_norm = None
     k_norm = None
     rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Gemma2_2B_Config:
@@ -96,6 +143,7 @@ class Gemma2_2B_Config:
     k_norm = None
     sliding_attention = None
     rope_scale = None
+    final_norm: bool = True

 @dataclass
 class Gemma3_4B_Config:
@@ -118,6 +166,7 @@ class Gemma3_4B_Config:
     k_norm = "gemma3"
     sliding_attention = [False, False, False, False, False, 1024]
     rope_scale = [1.0, 8.0]
+    final_norm: bool = True

 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
@@ -366,7 +415,12 @@ class Llama2_(nn.Module):
             transformer(config, index=i, device=device, dtype=dtype, ops=ops)
             for i in range(config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+        if config.final_norm:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+        else:
+            self.norm = None
+
         # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

     def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
@@ -402,8 +456,12 @@ class Llama2_(nn.Module):

         intermediate = None
         all_intermediate = None
+        only_layers = None
         if intermediate_output is not None:
-            if intermediate_output == "all":
+            if isinstance(intermediate_output, list):
+                all_intermediate = []
+                only_layers = set(intermediate_output)
+            elif intermediate_output == "all":
                 all_intermediate = []
                 intermediate_output = None
             elif intermediate_output < 0:
@@ -411,7 +469,8 @@ class Llama2_(nn.Module):

         for i, layer in enumerate(self.layers):
             if all_intermediate is not None:
-                all_intermediate.append(x.unsqueeze(1).clone())
+                if only_layers is None or (i in only_layers):
+                    all_intermediate.append(x.unsqueeze(1).clone())
             x = layer(
                 x=x,
                 attention_mask=mask,
@@ -421,14 +480,17 @@ class Llama2_(nn.Module):
             if i == intermediate_output:
                 intermediate = x.clone()

-        x = self.norm(x)
+        if self.norm is not None:
+            x = self.norm(x)

         if all_intermediate is not None:
-            all_intermediate.append(x.unsqueeze(1).clone())
+            if only_layers is None or ((i + 1) in only_layers):
+                all_intermediate.append(x.unsqueeze(1).clone())

         if all_intermediate is not None:
             intermediate = torch.cat(all_intermediate, dim=1)

-        if intermediate is not None and final_layer_norm_intermediate:
+        if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
             intermediate = self.norm(intermediate)

         return x, intermediate
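
With intermediate_output given as a list, hidden states are captured immediately before each selected layer index runs, and index num_layers selects the post-final-norm output (the (i + 1) check after the loop). A toy version of that capture logic:

    # States are recorded before each selected layer runs; selecting
    # len(layers) also records the final output.
    def run(layers, x, only):
        captured = []
        for i, f in enumerate(layers):
            if i in only:
                captured.append(x)
            x = f(x)
        if len(layers) in only:
            captured.append(x)
        return x, captured

    x, captured = run([lambda v: v + 1] * 4, 0, {1, 3, 4})
    assert (x, captured) == (4, [1, 3, 4])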
@@ -453,6 +515,15 @@ class Llama2(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Mistral3Small24B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Mistral3Small24BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_3B(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()
@@ -462,6 +533,15 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
         self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
         self.dtype = dtype

+class Qwen3_4B(BaseLlama, torch.nn.Module):
+    def __init__(self, config_dict, dtype, device, operations):
+        super().__init__()
+        config = Qwen3_4BConfig(**config_dict)
+        self.num_layers = config.num_hidden_layers
+
+        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+        self.dtype = dtype
+
 class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
     def __init__(self, config_dict, dtype, device, operations):
         super().__init__()

View File

@@ -17,12 +17,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
         self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"

-    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
         skip_template = False
         if text.startswith('<|im_start|>'):
             skip_template = True
         if text.startswith('<|start_header_id|>'):
             skip_template = True
+        if prevent_empty_text and text == '':
+            text = ' '

         if skip_template:
             llama_text = text

View File

@@ -0,0 +1,48 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+        super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
+
+class ZImageTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
+        self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+
+    def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+        if llama_template is None:
+            llama_text = self.llama_template.format(text)
+        else:
+            llama_text = llama_template.format(text)
+        tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+        return tokens
+
+class Qwen3_4BModel(sd1_clip.SDClipModel):
+    def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+class ZImageTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+    class ZImageTEModel_(ZImageTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            if llama_quantization_metadata is not None:
+                model_options["quantization_metadata"] = llama_quantization_metadata
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return ZImageTEModel_

View File

@@ -1109,3 +1109,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
         dim=1
     )
     return out
+
+def pack_latents(latents):
+    latent_shapes = []
+    tensors = []
+    for tensor in latents:
+        latent_shapes.append(tensor.shape)
+        tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
+
+    latent = torch.cat(tensors, dim=-1)
+    return latent, latent_shapes
+
+def unpack_latents(combined_latent, latent_shapes):
+    if len(latent_shapes) > 1:
+        output_tensors = []
+        for shape in latent_shapes:
+            cut = math.prod(shape[1:])
+            tens = combined_latent[:, :, :cut]
+            combined_latent = combined_latent[:, :, cut:]
+            output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
+    else:
+        output_tensors = combined_latent
+    return output_tensors
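
pack_latents flattens each latent to (batch, 1, -1) and concatenates along the last dimension; unpack_latents cuts the buffer back apart using prod(shape[1:]) elements per latent. A self-contained round-trip check mirroring the helpers above:

    import math
    import torch

    def pack(latents):
        shapes, flat = [], []
        for t in latents:
            shapes.append(t.shape)
            flat.append(t.reshape(t.shape[0], 1, -1))
        return torch.cat(flat, dim=-1), shapes

    a, b = torch.randn(1, 4, 8, 8), torch.randn(1, 16, 4, 4, 4)
    packed, shapes = pack([a, b])
    assert packed.shape == (1, 1, 4 * 8 * 8 + 16 * 4 * 4 * 4)

    cut = math.prod(shapes[0][1:])  # elements belonging to the first latent
    restored = packed[:, :, :cut].reshape([1] + list(shapes[0])[1:])
    assert torch.equal(restored, a)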

View File

@@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
             lora_diff = torch.mm(
                 mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
             ).reshape(weight.shape)
+            del mat1, mat2
             if dora_scale is not None:
                 weight = weight_decompose(
                     dora_scale,

View File

@ -8,7 +8,7 @@ import os
import textwrap import textwrap
import threading import threading
from enum import Enum from enum import Enum
from typing import Optional, Type, get_origin, get_args from typing import Optional, Type, get_origin, get_args, get_type_hints
class TypeTracker: class TypeTracker:
@@ -220,11 +220,18 @@ class AsyncToSyncConverter:
        self._async_instance = async_class(*args, **kwargs)

        # Handle annotated class attributes (like execution: Execution)
-        # Get all annotations from the class hierarchy
-        all_annotations = {}
-        for base_class in reversed(inspect.getmro(async_class)):
-            if hasattr(base_class, "__annotations__"):
-                all_annotations.update(base_class.__annotations__)
+        # Get all annotations from the class hierarchy and resolve string annotations
+        try:
+            # get_type_hints resolves string annotations to actual type objects
+            # This handles classes using 'from __future__ import annotations'
+            all_annotations = get_type_hints(async_class)
+        except Exception:
+            # Fallback to raw annotations if get_type_hints fails
+            # (e.g., for undefined forward references)
+            all_annotations = {}
+            for base_class in reversed(inspect.getmro(async_class)):
+                if hasattr(base_class, "__annotations__"):
+                    all_annotations.update(base_class.__annotations__)

        # For each annotated attribute, check if it needs to be created or wrapped
        for attr_name, attr_type in all_annotations.items():
@@ -625,15 +632,19 @@
        """Extract class attributes that are classes themselves."""
        class_attributes = []

+        # Get resolved type hints to handle string annotations
+        try:
+            type_hints = get_type_hints(async_class)
+        except Exception:
+            type_hints = {}

        # Look for class attributes that are classes
        for name, attr in sorted(inspect.getmembers(async_class)):
            if isinstance(attr, type) and not name.startswith("_"):
                class_attributes.append((name, attr))
-            elif (
-                hasattr(async_class, "__annotations__")
-                and name in async_class.__annotations__
-            ):
-                annotation = async_class.__annotations__[name]
+            elif name in type_hints:
+                # Use resolved type hint instead of raw annotation
+                annotation = type_hints[name]
                if isinstance(annotation, type):
                    class_attributes.append((name, annotation))
@@ -908,11 +919,15 @@
        attribute_mappings = {}

        # First check annotations for typed attributes (including from parent classes)
-        # Collect all annotations from the class hierarchy
-        all_annotations = {}
-        for base_class in reversed(inspect.getmro(async_class)):
-            if hasattr(base_class, "__annotations__"):
-                all_annotations.update(base_class.__annotations__)
+        # Resolve string annotations to actual types
+        try:
+            all_annotations = get_type_hints(async_class)
+        except Exception:
+            # Fallback to raw annotations
+            all_annotations = {}
+            for base_class in reversed(inspect.getmro(async_class)):
+                if hasattr(base_class, "__annotations__"):
+                    all_annotations.update(base_class.__annotations__)

        for attr_name, attr_type in sorted(all_annotations.items()):
            for class_name, class_type in class_attributes:
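The motivation for get_type_hints in all three hunks above, as a self-contained sketch: under PEP 563 (`from __future__ import annotations`), raw `__annotations__` hold strings, so isinstance checks against them silently fail.

    from __future__ import annotations
    from typing import get_type_hints

    class Execution: ...

    class AsyncApi:
        execution: Execution

    raw = AsyncApi.__annotations__["execution"]
    resolved = get_type_hints(AsyncApi)["execution"]
    assert isinstance(raw, str)      # 'Execution' -- would fail isinstance(raw, type)
    assert resolved is Execution     # an actual class again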

View File

@@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
-from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
+from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io as io
from . import _ui as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401

@@ -104,6 +104,8 @@ class Types:
    VideoCodec = VideoCodec
    VideoContainer = VideoContainer
    VideoComponents = VideoComponents
+    MESH = MESH
+    VOXEL = VOXEL

ComfyAPI = ComfyAPI_latest

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
+from fractions import Fraction
from typing import Optional, Union, IO
import io
import av
@@ -72,6 +73,33 @@
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video.
Default implementation uses :meth:`get_components`, which may require
loading all frames into memory. File-based implementations should
override this method and use container/stream metadata instead.
Returns:
Total number of frames as an integer.
"""
return int(self.get_components().images.shape[0])
def get_frame_rate(self) -> Fraction:
"""
Returns the frame rate of the video.
Default implementation materializes the video into memory via
`get_components()`. Subclasses that can inspect the underlying
container (e.g. `VideoFromFile`) should override this with a more
efficient implementation.
Returns:
Frame rate as a Fraction.
"""
return self.get_components().frame_rate
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
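The relationship between the two new accessors and the existing get_duration(), shown with hypothetical NTSC-style values:

    from fractions import Fraction

    fps = Fraction(30000, 1001)    # what get_frame_rate() might return
    frames = 300                   # what get_frame_count() might return
    print(float(frames / fps))     # ~10.01 seconds, matching get_duration()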

View File

@@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_frame_count(self) -> int:
"""
Returns the number of frames in the video without materializing them as
torch tensors.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# 1. Prefer the frames field if available
if video_stream.frames and video_stream.frames > 0:
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
if container.duration is not None and video_stream.average_rate:
duration_seconds = float(container.duration / av.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
duration_seconds = float(video_stream.duration * video_stream.time_base)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count == 0:
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
return frame_count
def get_frame_rate(self) -> Fraction:
"""
Returns the average frame rate of the video using container metadata
without decoding all frames.
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
if video_stream.average_rate:
return Fraction(video_stream.average_rate)
# Fallback: estimate from frames + duration if available
if video_stream.frames and container.duration:
duration_seconds = float(container.duration / av.time_base)
if duration_seconds > 0:
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
# Last resort: match get_components_internal default
return Fraction(1)
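The metadata fallback arithmetic in get_frame_rate above, in isolation (hypothetical numbers): 240 frames over a 10-second container reduce to an exact 24 fps.

    from fractions import Fraction

    estimated = Fraction(240 / 10.0).limit_denominator()
    print(estimated)  # 24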
def get_container_format(self) -> str:
"""
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
@@ -238,6 +303,13 @@
packet.stream = stream_map[packet.stream]
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
video_stream = next((s for s in container.streams if s.type == "video"), None)
if video_stream is None:
raise ValueError(f"No video stream found in file '{self.__file}'")
return video_stream
class VideoFromComponents(VideoInput):
"""
Class representing video input from tensors.

View File

@@ -27,6 +27,7 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty,
                                prune_dict, shallow_clone_class)
from comfy_api.latest._resources import Resources, ResourcesLocal
from comfy_execution.graph_utils import ExecutionBlocker
+from ._util import MESH, VOXEL

# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
@@ -628,6 +629,10 @@ class UpscaleModel(ComfyTypeIO):
if TYPE_CHECKING:
Type = ImageModelDescriptor
@comfytype(io_type="LATENT_UPSCALE_MODEL")
class LatentUpscaleModel(ComfyTypeIO):
Type = Any
@comfytype(io_type="AUDIO") @comfytype(io_type="AUDIO")
class Audio(ComfyTypeIO): class Audio(ComfyTypeIO):
class AudioDict(TypedDict): class AudioDict(TypedDict):
@ -656,11 +661,11 @@ class LossMap(ComfyTypeIO):
@comfytype(io_type="VOXEL") @comfytype(io_type="VOXEL")
class Voxel(ComfyTypeIO): class Voxel(ComfyTypeIO):
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 Type = VOXEL
@comfytype(io_type="MESH") @comfytype(io_type="MESH")
class Mesh(ComfyTypeIO): class Mesh(ComfyTypeIO):
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 Type = MESH
@comfytype(io_type="HOOKS") @comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO): class Hooks(ComfyTypeIO):

View File

@@ -1,8 +1,11 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
+from .geometry_types import VOXEL, MESH

__all__ = [
# Utility Types
"VideoContainer",
"VideoCodec",
"VideoComponents",
+"VOXEL",
+"MESH",
]

View File

@@ -0,0 +1,12 @@
import torch
class VOXEL:
def __init__(self, data: torch.Tensor):
self.data = data
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
self.vertices = vertices
self.faces = faces
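A minimal construction sketch for the two wrappers; the shapes are illustrative (a dense occupancy grid and a standard indexed triangle mesh), and the import path assumes the re-export added above.

    import torch
    from comfy_api.latest._util import VOXEL, MESH

    voxels = VOXEL(torch.zeros(1, 64, 64, 64))        # illustrative grid shape
    mesh = MESH(
        vertices=torch.zeros(8, 3),                   # 8 xyz vertices
        faces=torch.zeros(12, 3, dtype=torch.long),   # 12 triangles indexing vertices
    )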

View File

@@ -1,718 +0,0 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
import os
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
from comfy_api.input.video_types import VideoInput
from comfy_api.input.basic_types import AudioInput
from comfy_api_nodes.apis.client import (
ApiClient,
ApiEndpoint,
HttpMethod,
SynchronousOperation,
UploadRequest,
UploadResponse,
)
from server import PromptServer
from comfy.cli_args import args
import numpy as np
from PIL import Image
import torch
import math
import base64
import uuid
from io import BytesIO
import av
async def download_url_to_video_output(
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.
Args:
video_url: The URL of the video to download.
Returns:
A Comfy node `VIDEO` output.
"""
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
raise ValueError(error_msg)
return VideoFromFile(video_io)
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
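The scale-factor math in isolation: a 2048x2048 input against the default 1536*1024 budget gives scale_by = sqrt(1572864 / 4194304), roughly 0.612, landing just under the pixel cap.

    import math

    total_pixels = 1536 * 1024
    h, w = 2048, 2048
    scale_by = math.sqrt(total_pixels / (w * h))
    print(round(w * scale_by), round(h * scale_by))  # 1254 1254 (~1.57M pixels)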
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def validate_aspect_ratio(
aspect_ratio: str,
minimum_ratio: float,
maximum_ratio: float,
minimum_ratio_str: str,
maximum_ratio_str: str,
) -> float:
"""Validates and casts an aspect ratio string to a float.
Args:
aspect_ratio: The aspect ratio string to validate.
minimum_ratio: The minimum aspect ratio.
maximum_ratio: The maximum aspect ratio.
minimum_ratio_str: The minimum aspect ratio string.
maximum_ratio_str: The maximum aspect ratio string.
Returns:
The validated and cast aspect ratio.
Raises:
Exception: If the aspect ratio is not valid.
"""
# get ratio values
numbers = aspect_ratio.split(":")
if len(numbers) != 2:
raise TypeError(
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
)
try:
numerator = int(numbers[0])
denominator = int(numbers[1])
except ValueError as exc:
raise TypeError(
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
) from exc
calculated_ratio = numerator / denominator
# if not close to minimum and maximum, check bounds
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
calculated_ratio, maximum_ratio
):
if calculated_ratio < minimum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
if calculated_ratio > maximum_ratio:
raise TypeError(
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
)
return aspect_ratio
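Illustrative calls (the bounds are hypothetical); note the function returns the original string despite the `-> float` annotation:

    # assumes validate_aspect_ratio from the module above is in scope
    validate_aspect_ratio("16:9", minimum_ratio=0.5, maximum_ratio=2.0,
                          minimum_ratio_str="1:2", maximum_ratio_str="2:1")  # -> "16:9"
    # "21:9" with the same bounds raises TypeError (21/9 ~= 2.33 > 2.0)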
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
async def download_url_to_bytesio(
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.
Args:
url: The URL to download.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
BytesIO object containing the downloaded content.
"""
headers = {}
if url.startswith("/proxy/"):
url = str(args.comfy_api_base).rstrip("/") + url
auth_token = auth_kwargs.get("auth_token")
comfy_api_key = auth_kwargs.get("comfy_api_key")
if auth_token:
headers["Authorization"] = f"Bearer {auth_token}"
elif comfy_api_key:
headers["X-API-KEY"] = comfy_api_key
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url, headers=headers) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response_content))
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(
input_tensor.unsqueeze(0), total_pixels=total_pixels
).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = io.BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = (
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
)
return img_binary
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
Args:
file_bytes_io: BytesIO object containing the file data.
filename: The filename of the file.
upload_mime_type: MIME type of the file.
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded file.
"""
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
Uses the specified container and codec for saving the video before upload.
Args:
video: VideoInput object (Comfy VIDEO type).
auth_kwargs: Optional authentication token(s).
container: The video container format to use (default: MP4).
codec: The video codec to use (default: H264).
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
Returns:
The download URL for the uploaded video file.
"""
if max_duration is not None:
try:
actual_duration = video.duration_seconds
if actual_duration is not None and actual_duration > max_duration:
raise ValueError(
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
)
except Exception as e:
logging.error("Error getting video duration: %s", str(e))
raise ValueError(f"Could not verify video duration from source: {e}") from e
upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = io.BytesIO()
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform. If the audio was batched,
the first item is taken.
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# If batch is > 1, take first item
if waveform.shape[0] > 1:
waveform = waveform[0]
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = io.BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
filename: str = "uploaded_audio.mp4",
) -> str:
"""
Uploads a single audio input to ComfyUI API and returns its download URL.
Encodes the raw waveform into the specified format before uploading.
Args:
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
auth_kwargs: Optional authentication token(s).
Returns:
The download URL for the uploaded audio file.
"""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2 ** 15)
elif wav.dtype == torch.int32:
return wav.float() / (2 ** 31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(io.BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
Args:
image: Input torch.Tensor image.
max_images: Maximum number of images to upload.
auth_kwargs: Optional authentication token(s).
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
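Shape behavior of the pairing helper above, sketched (the call itself is left commented since it depends on common_upscale):

    import torch

    a = torch.zeros(1, 512, 512, 3)   # [B, H, W, C]
    b = torch.zeros(1, 256, 256, 3)   # smaller; would be resized to 512x512
    # image_tensor_pair_to_batch(a, b).shape == (2, 512, 512, 3)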
def get_size(path_or_object: Union[str, io.BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
def validate_container_format_is_mp4(video: VideoInput) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")

View File

@@ -1,17 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel
from . import PixverseDto
class ResponseData(BaseModel):
ErrCode: Optional[int] = None
ErrMsg: Optional[str] = None
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None

View File

@@ -1,57 +0,0 @@
# generated by datamodel-codegen:
# filename: filtered-openapi.yaml
# timestamp: 2025-04-29T23:44:54+00:00
from __future__ import annotations
from typing import Optional
from pydantic import BaseModel, Field
class V2OpenAPII2VResp(BaseModel):
video_id: Optional[int] = Field(None, description='Video_id')
class V2OpenAPIT2VReq(BaseModel):
aspect_ratio: str = Field(
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
)
duration: int = Field(
...,
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
examples=[5],
)
model: str = Field(
..., description='Model version (only supports v3.5)', examples=['v3.5']
)
motion_mode: Optional[str] = Field(
'normal',
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
examples=['normal'],
)
negative_prompt: Optional[str] = Field(
None, description='Negative prompt\n', max_length=2048
)
prompt: str = Field(..., description='Prompt', max_length=2048)
quality: str = Field(
...,
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
examples=['540p'],
)
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
style: Optional[str] = Field(
None,
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
examples=['anime'],
)
template_id: Optional[int] = Field(
None,
description='Template ID (template_id must be activated before use)',
examples=[302325299692608],
)
water_mark: Optional[bool] = Field(
False,
description='Watermark (true: add watermark, false: no watermark)',
examples=[False],
)

View File

@@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you wish to modify.')
class BFLFluxCannyImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxDepthImageRequest(BaseModel):
prompt: str = Field(..., description='Text prompt for image generation')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
@@ -108,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
# )
class Flux2ProGenerateRequest(BaseModel):
prompt: str = Field(...)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
seed: int | None = Field(None)
prompt_upsampling: bool | None = Field(None)
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you want to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
@@ -147,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
class BFLFluxProGenerateResponse(BaseModel):
-id: str = Field(..., description='The unique identifier for the generation task.')
-polling_url: str = Field(..., description='URL to poll for the generation result.')
+id: str = Field(..., description="The unique identifier for the generation task.")
+polling_url: str = Field(..., description="URL to poll for the generation result.")
+cost: float | None = Field(None, description="Price in cents")
class BFLStatus(str, Enum): class BFLStatus(str, Enum):
@@ -160,15 +146,8 @@ class BFLStatus(str, Enum):
error = "Error"

-class BFLFluxProStatusResponse(BaseModel):
+class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
-result: Optional[Dict[str, Any]] = Field(
-    None, description="The result of the task (null if not completed)."
-)
-progress: confloat(ge=0.0, le=1.0) = Field(
-    ..., description="The progress of the task (0.0 to 1.0)."
-)
-details: Optional[Dict[str, Any]] = Field(
-    None, description="Additional details about the task (null if not available)."
-)
+result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
+progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)

View File

@@ -1,981 +0,0 @@
"""
API Client Framework for api.comfy.org.
This module provides a flexible framework for making API requests from ComfyUI nodes.
It supports both synchronous and asynchronous API operations with proper type validation.
Key Components:
--------------
1. ApiClient - Handles HTTP requests with authentication and error handling
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
3. ApiOperation - Executes a single synchronous API operation
Usage Examples:
--------------
# Example 1: Synchronous API Operation
# ------------------------------------
# For a simple API call that returns the result immediately:
# 1. Create the API client
api_client = ApiClient(
base_url="https://api.example.com",
auth_token="your_auth_token_here",
comfy_api_key="your_comfy_api_key_here",
timeout=30.0,
verify_ssl=True
)
# 2. Define the endpoint
user_info_endpoint = ApiEndpoint(
path="/v1/users/me",
method=HttpMethod.GET,
request_model=EmptyRequest, # No request body needed
response_model=UserProfile, # Pydantic model for the response
query_params=None
)
# 3. Create the request object
request = EmptyRequest()
# 4. Create and execute the operation
operation = ApiOperation(
endpoint=user_info_endpoint,
request=request
)
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
# Example 2: Asynchronous API Operation with Polling
# -------------------------------------------------
# For an API that starts a task and requires polling for completion:
# 1. Define the endpoints (initial request and polling)
generate_image_endpoint = ApiEndpoint(
path="/v1/images/generate",
method=HttpMethod.POST,
request_model=ImageGenerationRequest,
response_model=TaskCreatedResponse,
query_params=None
)
check_task_endpoint = ApiEndpoint(
path="/v1/tasks/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=ImageGenerationResult,
query_params=None
)
# 2. Create the request object
request = ImageGenerationRequest(
prompt="a beautiful sunset over mountains",
width=1024,
height=1024,
num_images=1
)
# 3. Create and execute the polling operation
operation = PollingOperation(
initial_endpoint=generate_image_endpoint,
initial_request=request,
poll_endpoint=check_task_endpoint,
task_id_field="task_id",
status_field="status",
completed_statuses=["completed"],
failed_statuses=["failed", "error"]
)
# This will make the initial request and then poll until completion
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
"""
from __future__ import annotations
import aiohttp
import asyncio
import logging
import io
import os
import socket
from aiohttp.client_exceptions import ClientError, ClientResponseError
from typing import Type, Optional, Any, TypeVar, Generic, Callable
from enum import Enum
import json
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
pass
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
class UploadResponse(BaseModel):
download_url: str = Field(..., description="URL to GET uploaded file")
upload_url: str = Field(..., description="URL to PUT file to upload")
class HttpMethod(str, Enum):
GET = "GET"
POST = "POST"
PUT = "PUT"
DELETE = "DELETE"
PATCH = "PATCH"
class ApiClient:
"""
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[tuple[int, ...]] = None,
session: Optional[aiohttp.ClientSession] = None,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
self._session: Optional[aiohttp.ClientSession] = session
self._owns_session = session is None # Track if we have to close it
@staticmethod
def _generate_operation_id(path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
@staticmethod
def _create_json_payload_args(
data: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
return {
"json": data,
"headers": headers,
}
def _create_form_data_args(
self,
data: dict[str, Any] | None,
files: dict[str, Any] | None,
headers: Optional[dict[str, str]] = None,
multipart_parser: Callable | None = None,
) -> dict[str, Any]:
if headers and "Content-Type" in headers:
del headers["Content-Type"]
if multipart_parser and data:
data = multipart_parser(data)
if isinstance(data, aiohttp.FormData):
form = data # If the parser already returned a FormData, pass it through
else:
form = aiohttp.FormData(default_to_multipart=True)
if data: # regular text fields
for k, v in data.items():
if v is None:
continue # aiohttp fails to serialize "None" values
# aiohttp expects strings or bytes; convert enums etc.
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if files:
file_iter = files if isinstance(files, list) else files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue # aiohttp fails to serialize "None" values
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
if isinstance(file_obj, tuple):
filename, file_value, content_type = self._unpack_tuple(file_obj)
else:
file_value = file_obj
filename = getattr(file_obj, "name", field_name)
content_type = "application/octet-stream"
form.add_field(
name=field_name,
value=file_value,
filename=filename,
content_type=content_type,
)
return {"data": form, "headers": headers or {}}
@staticmethod
def _create_urlencoded_form_data_args(
data: dict[str, Any],
headers: Optional[dict[str, str]] = None,
) -> dict[str, Any]:
headers = headers or {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
return {
"data": data,
"headers": headers,
}
def get_headers(self) -> dict[str, str]:
"""Get headers for API requests, including authentication if available"""
headers = {"Content-Type": "application/json", "Accept": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
elif self.comfy_api_key:
headers["X-API-KEY"] = self.comfy_api_key
return headers
async def _check_connectivity(self, target_url: str) -> dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
try:
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
results["internet_accessible"] = resp.status < 500
except (ClientError, asyncio.TimeoutError, socket.gaierror):
results["is_local_issue"] = True
return results  # cannot reach the internet; exit early
# Now check API health endpoint
parsed = urlparse(target_url)
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
try:
async with session.get(health_url, ssl=self.verify_ssl) as resp:
results["api_accessible"] = resp.status < 500
except ClientError:
pass # leave as False
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
return results
async def request(
self,
method: str,
path: str,
params: Optional[dict[str, Any]] = None,
data: Optional[dict[str, Any]] = None,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
headers: Optional[dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
retry_count: int = 0, # Used internally for tracking retries
) -> dict[str, Any]:
"""
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
path: API endpoint path (will be joined with base_url)
params: Query parameters
data: body data
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
# Build full URL and merge headers
relative_path = path.lstrip("/")
url = urljoin(self.base_url, relative_path)
self._check_auth(self.auth_token, self.comfy_api_key)
request_headers = self.get_headers()
if headers:
request_headers.update(headers)
if files:
request_headers.pop("Content-Type", None)
if params:
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
logging.debug("[DEBUG] Request Headers: %s", request_headers)
logging.debug("[DEBUG] Files: %s", files)
logging.debug("[DEBUG] Params: %s", params)
logging.debug("[DEBUG] Data: %s", data)
if content_type == "application/x-www-form-urlencoded":
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
elif content_type == "multipart/form-data":
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]",
)
session = await self._get_session()
try:
async with session.request(
method,
url,
params=params,
ssl=self.verify_ssl,
**payload_args,
) as resp:
if resp.status >= 400:
try:
error_data = await resp.json()
except (aiohttp.ContentTypeError, json.JSONDecodeError):
error_data = await resp.text()
return await self._handle_http_error(
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
operation_id,
method,
url,
params,
data,
files,
headers,
content_type,
multipart_parser,
retry_count=retry_count,
response_content=error_data,
)
# Success: parse JSON (safely) and log
try:
payload = await resp.json()
response_content_to_log = payload
except (aiohttp.ContentTypeError, json.JSONDecodeError):
payload = {}
response_content_to_log = await resp.text()
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
return payload
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
# Treat as a *connection* problem: optionally retry, else escalate
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
self.max_retries, str(e))
await asyncio.sleep(delay)
return await self.request(
method,
path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# One final connectivity check for diagnostics
connectivity = await self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
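The retry cadence implied by the defaults above (retry_delay=1.0, retry_backoff_factor=2.0, max_retries=3), computed standalone:

    retry_delay, backoff, max_retries = 1.0, 2.0, 3
    print([retry_delay * backoff ** n for n in range(max_retries)])  # [1.0, 2.0, 4.0]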
@staticmethod
def _check_auth(auth_token, comfy_api_key):
"""Verify that an auth token is present or comfy_api_key is present"""
if auth_token is None and comfy_api_key is None:
raise Exception("Unauthorized: Please login first to use this node.")
return auth_token or comfy_api_key
@staticmethod
async def upload_file(
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> aiohttp.ClientResponse:
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers: dict[str, str] = {}
skip_auto_headers: set[str] = set()
if content_type:
headers["Content-Type"] = content_type
else:
# Tell aiohttp not to add a Content-Type header; an auto-added one would break the request signature and result in a 403 status.
skip_auto_headers.add("Content-Type")
# Extract file bytes
if isinstance(file, io.BytesIO):
file.seek(0)
data = file.read()
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
else:
raise ValueError("File must be BytesIO or str path")
parsed = urlparse(upload_url)
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data {len(data)} bytes]",
)
delay = retry_delay
for attempt in range(max_retries + 1):
try:
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.put(
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
) as resp:
resp.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content="File uploaded successfully.",
)
return resp
except (ClientError, asyncio.TimeoutError) as e:
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
response_status_code=e.status if hasattr(e, "status") else None,
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
response_content=None,
error_message=f"{type(e).__name__}: {str(e)}",
)
if attempt < max_retries:
logging.warning(
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
)
await asyncio.sleep(delay)
delay *= retry_backoff_factor
else:
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
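    # A minimal usage sketch for upload_file, assuming a hypothetical pre-signed
    # PUT URL (the URL and payload below are illustrative, not part of this module):
    #
    #     import asyncio, io
    #
    #     async def main():
    #         buf = io.BytesIO(b"example payload")
    #         await ApiClient.upload_file(
    #             "https://example.com/presigned-put-url",
    #             buf,
    #             content_type="application/octet-stream",
    #         )
    #
    #     asyncio.run(main())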
async def _handle_http_error(
self,
exc: ClientResponseError,
operation_id: str,
*req_meta,
retry_count: int,
response_content: dict | str = "",
) -> dict[str, Any]:
status_code = exc.status
if status_code == 401:
user_friendly = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_friendly = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_friendly = "Rate Limit Exceeded: Please try again later."
else:
if isinstance(response_content, dict):
if "error" in response_content and "message" in response_content["error"]:
user_friendly = f"API Error: {response_content['error']['message']}"
if "type" in response_content["error"]:
user_friendly += f" (Type: {response_content['error']['type']})"
else: # Handle cases where error is just a JSON dict with unknown format
user_friendly = f"API Error: {json.dumps(response_content)}"
else:
if len(response_content) < 200: # Arbitrary limit for display
user_friendly = f"API Error (raw): {response_content}"
else:
user_friendly = f"API Error (raw, status {response_content})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=req_meta[0],
request_url=req_meta[1],
response_status_code=exc.status,
response_headers=dict(req_meta[5]) if req_meta[5] else None,
response_content=response_content,
error_message=f"HTTP Error {exc.status}",
)
logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code)
if response_content:
logging.debug("[DEBUG] Response content: %s", response_content)
# Retry if eligible
if status_code in self.retry_status_codes and retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
"HTTP error %s. Retrying in %.2fs (%s/%s)",
status_code,
delay,
retry_count + 1,
self.max_retries,
)
await asyncio.sleep(delay)
return await self.request(
req_meta[0], # method
req_meta[1].replace(self.base_url, ""), # path
params=req_meta[2],
data=req_meta[3],
files=req_meta[4],
headers=req_meta[5],
content_type=req_meta[6],
multipart_parser=req_meta[7],
retry_count=retry_count + 1,
)
raise Exception(user_friendly) from exc
@staticmethod
def _unpack_tuple(t):
"""Helper to normalise (filename, file, content_type) tuples."""
if len(t) == 3:
return t
elif len(t) == 2:
return t[0], t[1], "application/octet-stream"
else:
raise ValueError("files tuple must be (filename, file[, content_type])")
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(timeout=timeout)
self._owns_session = True
return self._session
async def close(self) -> None:
if self._owns_session and self._session and not self._session.closed:
await self._session.close()
async def __aenter__(self) -> "ApiClient":
"""Allow usage as asynccontextmanager ensures clean teardown"""
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
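# A minimal sketch of using ApiClient as an async context manager, so the owned
# aiohttp session is always closed; the base URL and path are illustrative:
#
#     async def fetch_status() -> dict:
#         async with ApiClient(base_url="https://api.example.com") as client:
#             return await client.request("GET", "/status")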
class ApiEndpoint(Generic[T, R]):
"""Defines an API endpoint with its request and response types"""
def __init__(
self,
path: str,
method: HttpMethod,
request_model: Type[T],
response_model: Type[R],
query_params: Optional[dict[str, Any]] = None,
):
"""Initialize an API endpoint definition.
Args:
path: The URL path for this endpoint, can include placeholders like {id}
method: The HTTP method to use (GET, POST, etc.)
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
query_params: Optional dictionary of query parameters to include in the request
"""
self.path = path
self.method = method
self.request_model = request_model
self.response_model = response_model
self.query_params = query_params or {}
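# A minimal sketch of defining an endpoint; the models here are hypothetical
# examples, not real API schemas:
#
#     from pydantic import BaseModel
#
#     class PingRequest(BaseModel):
#         message: str
#
#     class PingResponse(BaseModel):
#         message: str
#
#     ping_endpoint = ApiEndpoint(
#         path="/ping",
#         method=HttpMethod.POST,
#         request_model=PingRequest,
#         response_model=PingResponse,
#     )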
class SynchronousOperation(Generic[T, R]):
"""Represents a single synchronous API operation."""
def __init__(
self,
endpoint: ApiEndpoint[T, R],
request: T,
files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[dict[str, str]] = None,
timeout: float = 7200.0,
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
) -> None:
self.endpoint = endpoint
self.request = request
self.files = files
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.timeout = timeout
self.verify_ssl = verify_ssl
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
request_dict: Optional[dict[str, Any]]
if isinstance(self.request, EmptyRequest):
request_dict = None
else:
request_dict = self.request.model_dump(exclude_none=True)
for k, v in list(request_dict.items()):
if isinstance(v, Enum):
request_dict[k] = v.value
logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path)
logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2))
logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params)
response_json = await client.request(
self.endpoint.method.value,
self.endpoint.path,
params=self.endpoint.query_params,
data=request_dict,
files=self.files,
content_type=self.content_type,
multipart_parser=self.multipart_parser,
)
logging.debug("=" * 50)
logging.debug("[DEBUG] RESPONSE DETAILS:")
logging.debug("[DEBUG] Status Code: 200 (Success)")
logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2))
logging.debug("=" * 50)
parsed_response = self.endpoint.response_model.model_validate(response_json)
logging.debug("[DEBUG] Parsed Response: %s", parsed_response)
return parsed_response
finally:
if owns_client:
await client.close()
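# A minimal sketch of running a SynchronousOperation against the hypothetical
# ping_endpoint sketched above; execute() POSTs the request model and validates
# the JSON reply into the endpoint's response model:
#
#     async def ping(api_key: str) -> "PingResponse":
#         op = SynchronousOperation(
#             endpoint=ping_endpoint,
#             request=PingRequest(message="hello"),
#             comfy_api_key=api_key,
#         )
#         return await op.execute()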
class TaskStatus(str, Enum):
"""Enum for task status values"""
COMPLETED = "completed"
FAILED = "failed"
PENDING = "pending"
class PollingOperation(Generic[T, R]):
"""Represents an asynchronous API operation that requires polling for completion."""
def __init__(
self,
poll_endpoint: ApiEndpoint[EmptyRequest, R],
completed_statuses: list[str],
failed_statuses: list[str],
*,
status_extractor: Callable[[R], Optional[str]],
progress_extractor: Callable[[R], Optional[float]] | None = None,
result_url_extractor: Callable[[R], Optional[str]] | None = None,
price_extractor: Callable[[R], Optional[float]] | None = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[dict[str, str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
) -> None:
self.poll_endpoint = poll_endpoint
self.request = request
self.api_base: str = api_base or args.comfy_api_base
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
if auth_kwargs is not None:
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.price_extractor = price_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
self.final_response: Optional[R] = None
self.extracted_price: Optional[float] = None
async def execute(self, client: Optional[ApiClient] = None) -> R:
owns_client = client is None
if owns_client:
client = ApiClient(
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
try:
return await self._poll_until_complete(client)
finally:
if owns_client:
await client.close()
def _display_text_on_node(self, text: str):
if not self.node_id:
return
if self.extracted_price is not None:
text = f"Price: ${self.extracted_price}\n{text}"
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int | float):
if not self.node_id:
return
if self.estimated_duration is not None:
remaining = max(0, int(self.estimated_duration) - time_completed)
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
else:
message = f"Task in progress: {time_completed}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
try:
status = self.status_extractor(response)
if status in self.completed_statuses:
return TaskStatus.COMPLETED
if status in self.failed_statuses:
return TaskStatus.FAILED
return TaskStatus.PENDING
except Exception as e:
logging.error("Error extracting status: %s", e)
return TaskStatus.PENDING
async def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
status = TaskStatus.PENDING
for poll_count in range(1, self.max_poll_attempts + 1):
try:
logging.debug("[DEBUG] Polling attempt #%s", poll_count)
request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
if poll_count == 1:
logging.debug(
"[DEBUG] Poll Request: %s %s",
self.poll_endpoint.method.value,
self.poll_endpoint.path,
)
logging.debug(
"[DEBUG] Poll Request Data: %s",
json.dumps(request_dict, indent=2) if request_dict else "None",
)
# Query task status
resp = await client.request(
self.poll_endpoint.method.value,
self.poll_endpoint.path,
params=self.poll_endpoint.query_params,
data=request_dict,
)
consecutive_errors = 0 # reset on success
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug("[DEBUG] Task Status: %s", status)
# If progress extractor is provided, extract progress
if self.progress_extractor:
new_progress = self.progress_extractor(response_obj)
if new_progress is not None:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if self.price_extractor:
price = self.price_extractor(response_obj)
if price is not None:
self.extracted_price = price
if status == TaskStatus.COMPLETED:
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
logging.debug("[DEBUG] %s", message)
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
return self.final_response
if status == TaskStatus.FAILED:
message = f"Task failed: {json.dumps(resp)}"
logging.error("[DEBUG] %s", message)
raise Exception(message)
logging.debug("[DEBUG] Task still pending, continuing to poll...")
# Task pending: wait before the next poll
for i in range(int(self.poll_interval)):
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
await asyncio.sleep(1)
except (LocalNetworkError, ApiServerError, NetworkError) as e:
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
) from e
logging.warning(
"Network error (%s/%s): %s",
consecutive_errors,
max_consecutive_errors,
str(e),
)
await asyncio.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error("[DEBUG] Polling error: %s", str(e))
logging.warning(
"Error during polling (attempt %s/%s): %s. Will retry in %s seconds.",
poll_count,
self.max_poll_attempts,
str(e),
self.poll_interval,
)
await asyncio.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
    f"Polling timed out after {self.max_poll_attempts} attempts "
    f"({self.max_poll_attempts * self.poll_interval} seconds). "
    "The operation may still be running on the server but is taking longer than expected."
)
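# A minimal sketch of polling a hypothetical task endpoint until it reports a
# terminal status; TaskResponse is an illustrative model with a `status` field:
#
#     async def wait_for_task(api_key: str) -> "TaskResponse":
#         op = PollingOperation(
#             poll_endpoint=ApiEndpoint(
#                 path="/tasks/123",
#                 method=HttpMethod.GET,
#                 request_model=EmptyRequest,
#                 response_model=TaskResponse,
#             ),
#             completed_statuses=["completed"],
#             failed_statuses=["failed"],
#             status_extractor=lambda r: r.status,
#             comfy_api_key=api_key,
#         )
#         return await op.execute()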

View File

@ -1,22 +1,230 @@
from datetime import date
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field
class GeminiSafetyCategory(str, Enum):
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
class GeminiSafetyThreshold(str, Enum):
OFF = "OFF"
BLOCK_NONE = "BLOCK_NONE"
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
class GeminiSafetySetting(BaseModel):
category: GeminiSafetyCategory
threshold: GeminiSafetyThreshold
class GeminiRole(str, Enum):
user = "user"
model = "model"
class GeminiMimeType(str, Enum):
application_pdf = "application/pdf"
audio_mpeg = "audio/mpeg"
audio_mp3 = "audio/mp3"
audio_wav = "audio/wav"
image_png = "image/png"
image_jpeg = "image/jpeg"
image_webp = "image/webp"
text_plain = "text/plain"
video_mov = "video/mov"
video_mpeg = "video/mpeg"
video_mp4 = "video/mp4"
video_mpg = "video/mpg"
video_avi = "video/avi"
video_wmv = "video/wmv"
video_mpegps = "video/mpegps"
video_flv = "video/flv"
class GeminiInlineData(BaseModel):
data: str | None = Field(
None,
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
)
mimeType: GeminiMimeType | None = Field(None)
class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None)
text: str | None = Field(None)
class GeminiTextPart(BaseModel):
text: str | None = Field(None)
class GeminiContent(BaseModel):
parts: list[GeminiPart] = Field([])
role: GeminiRole = Field(..., examples=["user"])
class GeminiSystemInstructionContent(BaseModel):
parts: list[GeminiTextPart] = Field(
...,
description="A list of ordered parts that make up a single message. "
"Different parts may have different IANA MIME types.",
)
role: GeminiRole = Field(
...,
description="The identity of the entity that creates the message. "
"The following values are supported: "
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
"model: This indicates that the message is generated by the model. "
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
"For non-multi-turn conversations, this field can be left blank or unset.",
)
class GeminiFunctionDeclaration(BaseModel):
description: str | None = Field(None)
name: str = Field(...)
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
class GeminiTool(BaseModel):
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
class GeminiOffset(BaseModel):
nanos: int | None = Field(None, ge=0, le=999999999)
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
class GeminiVideoMetadata(BaseModel):
endOffset: GeminiOffset | None = Field(None)
startOffset: GeminiOffset | None = Field(None)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(None, ge=0.0, le=2.0)
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
class GeminiImageConfig(BaseModel):
    aspectRatio: str | None = Field(None)
    imageSize: str | None = Field(None)

class GeminiImageGenerationConfig(GeminiGenerationConfig):
    responseModalities: list[str] | None = Field(None)
    imageConfig: GeminiImageConfig | None = Field(None)

class GeminiImageGenerateContentRequest(BaseModel):
    contents: list[GeminiContent] = Field(...)
    generationConfig: GeminiImageGenerationConfig | None = Field(None)
    safetySettings: list[GeminiSafetySetting] | None = Field(None)
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)
class GeminiGenerateContentRequest(BaseModel):
contents: list[GeminiContent] = Field(...)
generationConfig: GeminiGenerationConfig | None = Field(None)
safetySettings: list[GeminiSafetySetting] | None = Field(None)
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
tools: list[GeminiTool] | None = Field(None)
videoMetadata: GeminiVideoMetadata | None = Field(None)
class Modality(str, Enum):
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
TEXT = "TEXT"
IMAGE = "IMAGE"
VIDEO = "VIDEO"
AUDIO = "AUDIO"
DOCUMENT = "DOCUMENT"
class ModalityTokenCount(BaseModel):
modality: Modality | None = None
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
class Probability(str, Enum):
NEGLIGIBLE = "NEGLIGIBLE"
LOW = "LOW"
MEDIUM = "MEDIUM"
HIGH = "HIGH"
UNKNOWN = "UNKNOWN"
class GeminiSafetyRating(BaseModel):
category: GeminiSafetyCategory | None = None
probability: Probability | None = Field(
None,
description="The probability that the content violates the specified safety category",
)
class GeminiCitation(BaseModel):
authors: list[str] | None = None
endIndex: int | None = None
license: str | None = None
publicationDate: date | None = None
startIndex: int | None = None
title: str | None = None
uri: str | None = None
class GeminiCitationMetadata(BaseModel):
citations: list[GeminiCitation] | None = None
class GeminiCandidate(BaseModel):
citationMetadata: GeminiCitationMetadata | None = None
content: GeminiContent | None = None
finishReason: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiPromptFeedback(BaseModel):
blockReason: str | None = None
blockReasonMessage: str | None = None
safetyRatings: list[GeminiSafetyRating] | None = None
class GeminiUsageMetadata(BaseModel):
cachedContentTokenCount: int | None = Field(
None,
description="Output only. Number of tokens in the cached part in the input (the cached content).",
)
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of candidate tokens by modality."
)
promptTokenCount: int | None = Field(
None,
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
)
promptTokensDetails: list[ModalityTokenCount] | None = Field(
None, description="Breakdown of prompt tokens by modality."
)
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
class GeminiGenerateContentResponse(BaseModel):
candidates: list[GeminiCandidate] | None = Field(None)
promptFeedback: GeminiPromptFeedback | None = Field(None)
usageMetadata: GeminiUsageMetadata | None = Field(None)
modelVersion: str | None = Field(None)
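# A minimal sketch of building a text-only request from the models above and
# reading the first candidate; the prompt and raw_json are illustrative:
#
#     req = GeminiGenerateContentRequest(
#         contents=[
#             GeminiContent(
#                 role=GeminiRole.user,
#                 parts=[GeminiPart(text="Describe this node graph.")],
#             )
#         ],
#         generationConfig=GeminiGenerationConfig(temperature=0.7),
#     )
#     resp = GeminiGenerateContentResponse.model_validate(raw_json)
#     text = resp.candidates[0].content.parts[0].text if resp.candidates else None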

View File

@ -0,0 +1,120 @@
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class MinimaxBaseResponse(BaseModel):
status_code: int = Field(
...,
description='Status code. 0 indicates success, other values indicate errors.',
)
status_msg: str = Field(
..., description='Specific error details or success message.'
)
class File(BaseModel):
bytes: Optional[int] = Field(None, description='File size in bytes')
created_at: Optional[int] = Field(
None, description='Unix timestamp when the file was created, in seconds'
)
download_url: Optional[str] = Field(
None, description='The URL to download the video'
)
backup_download_url: Optional[str] = Field(
None, description='The backup URL to download the video'
)
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
filename: Optional[str] = Field(None, description='The name of the file')
purpose: Optional[str] = Field(None, description='The purpose of using the file')
class MinimaxFileRetrieveResponse(BaseModel):
base_resp: MinimaxBaseResponse
file: File
class MiniMaxModel(str, Enum):
T2V_01_Director = 'T2V-01-Director'
I2V_01_Director = 'I2V-01-Director'
S2V_01 = 'S2V-01'
I2V_01 = 'I2V-01'
I2V_01_live = 'I2V-01-live'
T2V_01 = 'T2V-01'
Hailuo_02 = 'MiniMax-Hailuo-02'
class Status6(str, Enum):
Queueing = 'Queueing'
Preparing = 'Preparing'
Processing = 'Processing'
Success = 'Success'
Fail = 'Fail'
class MinimaxTaskResultResponse(BaseModel):
base_resp: MinimaxBaseResponse
file_id: Optional[str] = Field(
None,
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
)
status: Status6 = Field(
...,
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
)
task_id: str = Field(..., description='The task ID being queried.')
class SubjectReferenceItem(BaseModel):
image: Optional[str] = Field(
None, description='URL or base64 encoding of the subject reference image.'
)
mask: Optional[str] = Field(
None,
description='URL or base64 encoding of the mask for the subject reference image.',
)
class MinimaxVideoGenerationRequest(BaseModel):
callback_url: Optional[str] = Field(
None,
description='Optional. URL to receive real-time status updates about the video generation task.',
)
first_frame_image: Optional[str] = Field(
None,
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
)
model: MiniMaxModel = Field(
...,
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
)
prompt: Optional[str] = Field(
None,
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
max_length=2000,
)
prompt_optimizer: Optional[bool] = Field(
True,
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
)
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
None,
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
)
duration: Optional[int] = Field(
None,
description="The length of the output video in seconds."
)
resolution: Optional[str] = Field(
None,
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
)
class MinimaxVideoGenerationResponse(BaseModel):
base_resp: MinimaxBaseResponse
task_id: str = Field(
..., description='The task ID for the asynchronous video generation task.'
)
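# A minimal sketch of constructing a text-to-video request with the models
# above; the prompt text is illustrative:
#
#     req = MinimaxVideoGenerationRequest(
#         model=MiniMaxModel.T2V_01,
#         prompt="A timelapse of clouds over mountains, [pan right]",
#         prompt_optimizer=True,
#     )
#     payload = req.model_dump(exclude_none=True)  # ready to send as JSON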

View File

@ -0,0 +1,133 @@
from typing import Optional, Union
from pydantic import BaseModel, Field
class ImageEnhanceRequest(BaseModel):
model: str = Field("Reimagine")
output_format: str = Field("jpeg")
subject_detection: str = Field("All")
face_enhancement: bool = Field(True)
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
source_url: str = Field(...)
output_width: Optional[int] = Field(None)
output_height: Optional[int] = Field(None)
crop_to_fill: bool = Field(False)
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
face_preservation: str = Field("true", description="To preserve the identity of characters")
color_preservation: str = Field("true", description="To preserve the original color")
class ImageAsyncTaskResponse(BaseModel):
process_id: str = Field(...)
class ImageStatusResponse(BaseModel):
process_id: str = Field(...)
status: str = Field(...)
progress: Optional[int] = Field(None)
credits: int = Field(...)
class ImageDownloadResponse(BaseModel):
download_url: str = Field(...)
expiry: int = Field(...)
class Resolution(BaseModel):
width: int = Field(...)
height: int = Field(...)
class CreateCreateVideoRequestSource(BaseModel):
container: str = Field(...)
size: int = Field(..., description="Size of the video file in bytes")
duration: int = Field(..., description="Duration of the video file in seconds")
frameCount: int = Field(..., description="Total number of frames in the video")
frameRate: int = Field(...)
resolution: Resolution = Field(...)
class VideoFrameInterpolationFilter(BaseModel):
model: str = Field(...)
slowmo: Optional[int] = Field(None)
fps: int = Field(...)
duplicate: bool = Field(...)
duplicate_threshold: float = Field(...)
class VideoEnhancementFilter(BaseModel):
model: str = Field(...)
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
compression: Optional[float] = Field(None, description="Strength of compression recovery")
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
noise: Optional[float] = Field(None, description="Amount of noise reduction")
halo: Optional[float] = Field(None, description="Amount of halo reduction")
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
class OutputInformationVideo(BaseModel):
resolution: Resolution = Field(...)
frameRate: int = Field(...)
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
audioTransfer: str = Field(..., description="Copy, Convert, None")
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
class Overrides(BaseModel):
isPaidDiffusion: bool = Field(True)
class CreateVideoRequest(BaseModel):
source: CreateCreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
class CreateVideoResponse(BaseModel):
requestId: str = Field(...)
class VideoAcceptResponse(BaseModel):
uploadId: str = Field(...)
urls: list[str] = Field(...)
class VideoCompleteUploadRequestPart(BaseModel):
partNum: int = Field(...)
eTag: str = Field(...)
class VideoCompleteUploadRequest(BaseModel):
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
class VideoCompleteUploadResponse(BaseModel):
message: str = Field(..., description="Confirmation message")
class VideoStatusResponseEstimates(BaseModel):
cost: list[int] = Field(...)
class VideoStatusResponseDownloadUrl(BaseModel):
url: str = Field(...)
class VideoStatusResponse(BaseModel):
status: str = Field(...)
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
progress: Optional[float] = Field(None)
message: Optional[str] = Field("")
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)

View File

@ -1,13 +1,20 @@
from __future__ import annotations

from enum import Enum
from typing import Optional, List, Dict, Any, Union

from pydantic import BaseModel, Field, RootModel

class TripoModelVersion(str, Enum):
    v2_5_20250123 = 'v2.5-20250123'
    v2_0_20240919 = 'v2.0-20240919'
    v1_4_20240625 = 'v1.4-20240625'

class TripoTextureQuality(str, Enum):
    standard = 'standard'
    detailed = 'detailed'

class TripoStyle(str, Enum):
    PERSON_TO_CARTOON = "person:person2cartoon"
    ANIMAL_VENOM = "animal:venom"
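# A minimal sketch of selecting values from the enums above; the dictionary
# keys are illustrative, not the actual Tripo request fields:
#
#     version = TripoModelVersion.v2_5_20250123
#     quality = TripoTextureQuality.detailed
#     params = {"model_version": version.value, "texture_quality": quality.value}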

View File

@ -0,0 +1,111 @@
from typing import Optional, Union
from enum import Enum
from pydantic import BaseModel, Field
class Image2(BaseModel):
bytesBase64Encoded: str
gcsUri: Optional[str] = None
mimeType: Optional[str] = None
class Image3(BaseModel):
bytesBase64Encoded: Optional[str] = None
gcsUri: str
mimeType: Optional[str] = None
class Instance1(BaseModel):
image: Optional[Union[Image2, Image3]] = Field(
None, description='Optional image to guide video generation'
)
prompt: str = Field(..., description='Text description of the video')
class PersonGeneration1(str, Enum):
ALLOW = 'ALLOW'
BLOCK = 'BLOCK'
class Parameters1(BaseModel):
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
durationSeconds: Optional[int] = None
enhancePrompt: Optional[bool] = None
generateAudio: Optional[bool] = Field(
None,
description='Generate audio for the video. Only supported by veo 3 models.',
)
negativePrompt: Optional[str] = None
personGeneration: Optional[PersonGeneration1] = None
sampleCount: Optional[int] = None
seed: Optional[int] = None
storageUri: Optional[str] = Field(
None, description='Optional Cloud Storage URI to upload the video'
)
class VeoGenVidRequest(BaseModel):
instances: Optional[list[Instance1]] = None
parameters: Optional[Parameters1] = None
class VeoGenVidResponse(BaseModel):
name: str = Field(
...,
description='Operation resource name',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
],
)
class VeoGenVidPollRequest(BaseModel):
operationName: str = Field(
...,
description='Full operation name (from predict response)',
examples=[
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
],
)
class Video(BaseModel):
bytesBase64Encoded: Optional[str] = Field(
None, description='Base64-encoded video content'
)
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
mimeType: Optional[str] = Field(None, description='Video MIME type')
class Error1(BaseModel):
code: Optional[int] = Field(None, description='Error code')
message: Optional[str] = Field(None, description='Error message')
class Response1(BaseModel):
field_type: Optional[str] = Field(
None,
alias='@type',
examples=[
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
],
)
raiMediaFilteredCount: Optional[int] = Field(
None, description='Count of media filtered by responsible AI policies'
)
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
class VeoGenVidPollResponse(BaseModel):
done: Optional[bool] = None
error: Optional[Error1] = Field(
None, description='Error details if operation failed'
)
name: Optional[str] = None
response: Optional[Response1] = Field(
None, description='The actual prediction response if done is true'
)
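# A minimal sketch of interpreting a poll response with the models above;
# raw_json stands for the JSON payload returned by the polling endpoint:
#
#     poll = VeoGenVidPollResponse.model_validate(raw_json)
#     if poll.done:
#         if poll.error is not None:
#             raise RuntimeError(f"Veo generation failed: {poll.error.message}")
#         videos = poll.response.videos if poll.response else None
#         if videos:
#             first = videos[0].bytesBase64Encoded or videos[0].gcsUri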

View File

@ -1,146 +1,47 @@
from inspect import cleandoc

import torch
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.bfl_api import (
    BFLFluxExpandImageRequest,
    BFLFluxFillImageRequest,
    BFLFluxKontextProGenerateRequest,
    BFLFluxProGenerateResponse,
    BFLFluxProUltraGenerateRequest,
    BFLFluxStatusResponse,
    BFLStatus,
    Flux2ProGenerateRequest,
)
from comfy_api_nodes.util import (
    ApiEndpoint,
    download_url_to_image_tensor,
    get_number_of_images,
    poll_op,
    resize_mask_to_image,
    sync_op,
    tensor_to_base64_string,
    validate_aspect_ratio_string,
    validate_string,
)
def convert_mask_to_image(mask: torch.Tensor):
    """
    Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
    """
    mask = mask.unsqueeze(-1)
    mask = torch.cat([mask] * 3, dim=-1)
    return mask
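# A quick shape check for convert_mask_to_image: a [B, H, W] mask becomes a
# [B, H, W, 3] tensor (the sizes below are illustrative):
#
#     mask = torch.zeros(1, 64, 64)
#     image = convert_mask_to_image(mask)
#     assert image.shape == (1, 64, 64, 3)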
async def handle_bfl_synchronous_operation(
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = await operation.execute()
return await _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)
async def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
retries_404 = 0
max_retries_404 = 5
retry_404_seconds = 2
retry_202_seconds = 2
retry_pending_seconds = 1
async with aiohttp.ClientSession() as session:
# NOTE: should the `while True` loop be replaced with a check for workflow interruption?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
async with session.get(polling_url) as response:
if response.status == 200:
result = await response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
async with session.get(img_url) as img_resp:
return process_image_response(await img_resp.content.read())
elif result["status"] in [
BFLStatus.request_moderated,
BFLStatus.content_moderated,
]:
status = result["status"]
raise Exception(
f"BFL API did not return an image due to: {status}."
)
elif result["status"] == BFLStatus.error:
raise Exception(f"BFL API encountered an error: {result}.")
elif result["status"] == BFLStatus.pending:
await asyncio.sleep(retry_pending_seconds)
continue
elif response.status == 404:
if retries_404 < max_retries_404:
retries_404 += 1
await asyncio.sleep(retry_404_seconds)
continue
raise Exception(
f"BFL API could not find task after {max_retries_404} tries."
)
elif response.status == 202:
await asyncio.sleep(retry_202_seconds)
elif time.time() - start_time > timeout:
raise Exception(
f"BFL API experienced a timeout; could not return request under {timeout} seconds."
)
else:
raise Exception(f"BFL API encountered an error: {response.json()}")
def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
# remove batch dimension if present
if len(scaled_image.shape) > 3:
scaled_image = scaled_image[0]
image_np = (scaled_image.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
return base64.b64encode(img_byte_arr.getvalue()).decode()
class FluxProUltraImageNode(IO.ComfyNode):
    """
    Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
    """

    @classmethod
    def define_schema(cls) -> IO.Schema:
        return IO.Schema(
@ -158,7 +59,9 @@ class FluxProUltraImageNode(IO.ComfyNode):
            IO.Boolean.Input(
                "prompt_upsampling",
                default=False,
                tooltip="Whether to perform upsampling on the prompt. "
                "If active, automatically modifies the prompt for more creative generation, "
                "but results are nondeterministic (same seed will not produce exactly the same result).",
            ),
            IO.Int.Input(
                "seed",
@ -203,16 +106,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
    @classmethod
    def validate_inputs(cls, aspect_ratio: str):
        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
        return True

    @classmethod
@ -220,49 +114,44 @@ class FluxProUltraImageNode(IO.ComfyNode):
        cls,
        prompt: str,
        aspect_ratio: str,
        prompt_upsampling: bool = False,
        raw: bool = False,
        seed: int = 0,
        image_prompt: torch.Tensor | None = None,
        image_prompt_strength: float = 0.1,
    ) -> IO.NodeOutput:
        if image_prompt is None:
            validate_string(prompt, strip_whitespace=False)
        initial_response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"),
            response_model=BFLFluxProGenerateResponse,
            data=BFLFluxProUltraGenerateRequest(
                prompt=prompt,
                prompt_upsampling=prompt_upsampling,
                seed=seed,
                aspect_ratio=aspect_ratio,
                raw=raw,
                image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)),
                image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)),
            ),
        )
        response = await poll_op(
            cls,
            ApiEndpoint(initial_response.polling_url),
            response_model=BFLFluxStatusResponse,
            status_extractor=lambda r: r.status,
            progress_extractor=lambda r: r.progress,
            completed_statuses=[BFLStatus.ready],
            failed_statuses=[
                BFLStatus.request_moderated,
                BFLStatus.content_moderated,
                BFLStatus.error,
                BFLStatus.task_not_found,
            ],
            queued_statuses=[],
        )
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxKontextProImageNode(IO.ComfyNode):
@ -270,11 +159,6 @@ class FluxKontextProImageNode(IO.ComfyNode):
    Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.
    """

    @classmethod
    def define_schema(cls) -> IO.Schema:
        return IO.Schema(
@ -347,46 +231,43 @@ class FluxKontextProImageNode(IO.ComfyNode):
        aspect_ratio: str,
        guidance: float,
        steps: int,
        input_image: torch.Tensor | None = None,
        seed=0,
        prompt_upsampling=False,
    ) -> IO.NodeOutput:
        validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1))
        if input_image is None:
            validate_string(prompt, strip_whitespace=False)
        initial_response = await sync_op(
            cls,
            ApiEndpoint(path=cls.BFL_PATH, method="POST"),
            response_model=BFLFluxProGenerateResponse,
            data=BFLFluxKontextProGenerateRequest(
                prompt=prompt,
                prompt_upsampling=prompt_upsampling,
                guidance=round(guidance, 1),
                steps=steps,
                seed=seed,
                aspect_ratio=aspect_ratio,
                input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)),
            ),
        )
        response = await poll_op(
            cls,
            ApiEndpoint(initial_response.polling_url),
            response_model=BFLFluxStatusResponse,
            status_extractor=lambda r: r.status,
            progress_extractor=lambda r: r.progress,
            completed_statuses=[BFLStatus.ready],
            failed_statuses=[
                BFLStatus.request_moderated,
                BFLStatus.content_moderated,
                BFLStatus.error,
                BFLStatus.task_not_found,
            ],
            queued_statuses=[],
        )
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxKontextMaxImageNode(FluxKontextProImageNode):
@ -400,117 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode):
DISPLAY_NAME = "Flux.1 Kontext [max] Image" DISPLAY_NAME = "Flux.1 Kontext [max] Image"
class FluxProImageNode(IO.ComfyNode):
"""
Generates images synchronously based on prompt and resolution.
"""
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxProImageNode",
display_name="Flux 1.1 [pro] Image",
category="api node/image/BFL",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt for the image generation",
),
IO.Boolean.Input(
"prompt_upsampling",
default=False,
tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
),
IO.Int.Input(
"width",
default=1024,
min=256,
max=1440,
step=32,
),
IO.Int.Input(
"height",
default=768,
min=256,
max=1440,
step=32,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
IO.Image.Input(
"image_prompt",
optional=True,
),
# "image_prompt_strength": (
# IO.FLOAT,
# {
# "default": 0.1,
# "min": 0.0,
# "max": 1.0,
# "step": 0.01,
# "tooltip": "Blend between the prompt and the image prompt.",
# },
# ),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
prompt: str,
prompt_upsampling,
width: int,
height: int,
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
) -> IO.NodeOutput:
image_prompt = (
image_prompt
if image_prompt is None
else convert_image_to_base64(image_prompt)
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/bfl/flux-pro-1.1/generate",
method=HttpMethod.POST,
request_model=BFLFluxProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
width=width,
height=height,
seed=seed,
image_prompt=image_prompt,
),
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
)
output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
return IO.NodeOutput(output_image)
class FluxProExpandNode(IO.ComfyNode):
    """
    Outpaints image based on prompt.
@ -534,7 +304,9 @@ class FluxProExpandNode(IO.ComfyNode):
            IO.Boolean.Input(
                "prompt_upsampling",
                default=False,
                tooltip="Whether to perform upsampling on the prompt. "
                "If active, automatically modifies the prompt for more creative generation, "
                "but results are nondeterministic (same seed will not produce exactly the same result).",
            ),
            IO.Int.Input(
                "top",
@ -610,16 +382,11 @@ class FluxProExpandNode(IO.ComfyNode):
        guidance: float,
        seed=0,
    ) -> IO.NodeOutput:
        initial_response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"),
            response_model=BFLFluxProGenerateResponse,
            data=BFLFluxExpandImageRequest(
                prompt=prompt,
                prompt_upsampling=prompt_upsampling,
                top=top,
@ -629,16 +396,25 @@ class FluxProExpandNode(IO.ComfyNode):
                steps=steps,
                guidance=guidance,
                seed=seed,
                image=tensor_to_base64_string(image),
            ),
        )
        response = await poll_op(
            cls,
            ApiEndpoint(initial_response.polling_url),
            response_model=BFLFluxStatusResponse,
            status_extractor=lambda r: r.status,
            progress_extractor=lambda r: r.progress,
            completed_statuses=[BFLStatus.ready],
            failed_statuses=[
                BFLStatus.request_moderated,
                BFLStatus.content_moderated,
                BFLStatus.error,
                BFLStatus.task_not_found,
            ],
            queued_statuses=[],
        )
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxProFillNode(IO.ComfyNode):
@ -665,7 +441,9 @@ class FluxProFillNode(IO.ComfyNode):
            IO.Boolean.Input(
                "prompt_upsampling",
                default=False,
                tooltip="Whether to perform upsampling on the prompt. "
                "If active, automatically modifies the prompt for more creative generation, "
                "but results are nondeterministic (same seed will not produce exactly the same result).",
            ),
            IO.Float.Input(
                "guidance",
@ -712,94 +490,68 @@ class FluxProFillNode(IO.ComfyNode):
    ) -> IO.NodeOutput:
        # prepare mask
        mask = resize_mask_to_image(mask, image)
        mask = tensor_to_base64_string(convert_mask_to_image(mask))
        initial_response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"),
            response_model=BFLFluxProGenerateResponse,
            data=BFLFluxFillImageRequest(
                prompt=prompt,
                prompt_upsampling=prompt_upsampling,
                steps=steps,
                guidance=guidance,
                seed=seed,
                image=tensor_to_base64_string(image[:, :, :, :3]),  # make sure image will have alpha channel removed
                mask=mask,
            ),
        )
        response = await poll_op(
            cls,
            ApiEndpoint(initial_response.polling_url),
            response_model=BFLFluxStatusResponse,
            status_extractor=lambda r: r.status,
            progress_extractor=lambda r: r.progress,
            completed_statuses=[BFLStatus.ready],
            failed_statuses=[
                BFLStatus.request_moderated,
                BFLStatus.content_moderated,
                BFLStatus.error,
                BFLStatus.task_not_found,
            ],
            queued_statuses=[],
        )
        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):

    @classmethod
    def define_schema(cls) -> IO.Schema:
        return IO.Schema(
            node_id="Flux2ProImageNode",
            display_name="Flux.2 [pro] Image",
            category="api node/image/BFL",
            description="Generates images synchronously based on prompt and resolution.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Prompt for the image generation or edit",
                ),
                IO.Int.Input(
                    "width",
                    default=1024,
                    min=256,
                    max=2048,
                    step=32,
                ),
                IO.Int.Input(
                    "height",
                    default=768,
                    min=256,
                    max=2048,
                    step=32,
                ),
                IO.Int.Input(
                    "seed",
@@ -809,6 +561,14 @@ class FluxProCannyNode(IO.ComfyNode):
                    control_after_generate=True,
                    tooltip="The random seed used for creating the noise.",
                ),
+                IO.Boolean.Input(
+                    "prompt_upsampling",
+                    default=False,
+                    tooltip="Whether to perform upsampling on the prompt. "
+                    "If active, automatically modifies the prompt for more creative generation, "
+                    "but results are nondeterministic (same seed will not produce exactly the same result).",
+                ),
+                IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
            ],
            outputs=[IO.Image.Output()],
            hidden=[
@@ -822,162 +582,54 @@ class FluxProCannyNode(IO.ComfyNode):
    @classmethod
    async def execute(
        cls,
-        control_image: torch.Tensor,
        prompt: str,
+        width: int,
+        height: int,
+        seed: int,
        prompt_upsampling: bool,
-        canny_low_threshold: float,
-        canny_high_threshold: float,
-        skip_preprocessing: bool,
-        steps: int,
-        guidance: float,
-        seed=0,
+        images: torch.Tensor | None = None,
    ) -> IO.NodeOutput:
-        control_image = convert_image_to_base64(control_image[:, :, :, :3])
-        preprocessed_image = None
-
-        # scale canny threshold between 0-500, to match BFL's API
-        def scale_value(value: float, min_val=0, max_val=500):
-            return min_val + value * (max_val - min_val)
-        canny_low_threshold = int(round(scale_value(canny_low_threshold)))
-        canny_high_threshold = int(round(scale_value(canny_high_threshold)))
-
-        if skip_preprocessing:
-            preprocessed_image = control_image
-            control_image = None
-            canny_low_threshold = None
-            canny_high_threshold = None
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/bfl/flux-pro-1.0-canny/generate",
-                method=HttpMethod.POST,
-                request_model=BFLFluxCannyImageRequest,
-                response_model=BFLFluxProGenerateResponse,
-            ),
-            request=BFLFluxCannyImageRequest(
-                prompt=prompt,
-                prompt_upsampling=prompt_upsampling,
-                steps=steps,
-                guidance=guidance,
-                seed=seed,
-                control_image=control_image,
-                canny_low_threshold=canny_low_threshold,
-                canny_high_threshold=canny_high_threshold,
-                preprocessed_image=preprocessed_image,
-            ),
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-        )
-        output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
-        return IO.NodeOutput(output_image)
-
-
-class FluxProDepthNode(IO.ComfyNode):
-    """
-    Generate image using a control image (depth).
-    """
-
-    @classmethod
-    def define_schema(cls) -> IO.Schema:
-        return IO.Schema(
-            node_id="FluxProDepthNode",
-            display_name="Flux.1 Depth Control Image",
-            category="api node/image/BFL",
-            description=cleandoc(cls.__doc__ or ""),
-            inputs=[
-                IO.Image.Input("control_image"),
-                IO.String.Input(
-                    "prompt",
-                    multiline=True,
-                    default="",
-                    tooltip="Prompt for the image generation",
-                ),
-                IO.Boolean.Input(
-                    "prompt_upsampling",
-                    default=False,
-                    tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
-                ),
-                IO.Boolean.Input(
-                    "skip_preprocessing",
-                    default=False,
-                    tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
-                ),
-                IO.Float.Input(
-                    "guidance",
-                    default=15,
-                    min=1,
-                    max=100,
-                    tooltip="Guidance strength for the image generation process",
-                ),
-                IO.Int.Input(
-                    "steps",
-                    default=50,
-                    min=15,
-                    max=50,
-                    tooltip="Number of steps for the image generation process",
-                ),
-                IO.Int.Input(
-                    "seed",
-                    default=0,
-                    min=0,
-                    max=0xFFFFFFFFFFFFFFFF,
-                    control_after_generate=True,
-                    tooltip="The random seed used for creating the noise.",
-                ),
-            ],
-            outputs=[IO.Image.Output()],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
-            is_api_node=True,
-        )
-
-    @classmethod
-    async def execute(
-        cls,
-        control_image: torch.Tensor,
-        prompt: str,
-        prompt_upsampling: bool,
-        skip_preprocessing: bool,
-        steps: int,
-        guidance: float,
-        seed=0,
-    ) -> IO.NodeOutput:
-        control_image = convert_image_to_base64(control_image[:,:,:,:3])
-        preprocessed_image = None
-        if skip_preprocessing:
-            preprocessed_image = control_image
-            control_image = None
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/bfl/flux-pro-1.0-depth/generate",
-                method=HttpMethod.POST,
-                request_model=BFLFluxDepthImageRequest,
-                response_model=BFLFluxProGenerateResponse,
-            ),
-            request=BFLFluxDepthImageRequest(
-                prompt=prompt,
-                prompt_upsampling=prompt_upsampling,
-                steps=steps,
-                guidance=guidance,
-                seed=seed,
-                control_image=control_image,
-                preprocessed_image=preprocessed_image,
-            ),
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-        )
-        output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id)
-        return IO.NodeOutput(output_image)
+        reference_images = {}
+        if images is not None:
+            if get_number_of_images(images) > 9:
+                raise ValueError("The current maximum number of supported images is 9.")
+            for image_index in range(images.shape[0]):
+                key_name = f"input_image_{image_index + 1}" if image_index else "input_image"
+                reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
+            response_model=BFLFluxProGenerateResponse,
+            data=Flux2ProGenerateRequest(
+                prompt=prompt,
+                width=width,
+                height=height,
+                seed=seed,
+                prompt_upsampling=prompt_upsampling,
+                **reference_images,
+            ),
+        )
+
+        def price_extractor(_r: BaseModel) -> float | None:
+            return None if initial_response.cost is None else initial_response.cost / 100
+
+        response = await poll_op(
+            cls,
+            ApiEndpoint(initial_response.polling_url),
+            response_model=BFLFluxStatusResponse,
+            status_extractor=lambda r: r.status,
+            progress_extractor=lambda r: r.progress,
+            price_extractor=price_extractor,
+            completed_statuses=[BFLStatus.ready],
+            failed_statuses=[
+                BFLStatus.request_moderated,
+                BFLStatus.content_moderated,
+                BFLStatus.error,
+                BFLStatus.task_not_found,
+            ],
+            queued_statuses=[],
+        )
+        return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
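Note on the loop above: the reference-image fields follow BFL's naming, where the first image is "input_image" and later images are numbered from 2. A minimal illustration of that naming (the actual serialization happens inside Flux2ProGenerateRequest and tensor_to_base64_string); the price_extractor dividing cost by 100 suggests the proxy reports cost in cents, which should be treated as an assumption:

    # Illustration only: reproduces the key naming used in Flux2ProImageNode.execute.
    def reference_image_keys(n: int) -> list[str]:
        # index 0 -> "input_image", index 1 -> "input_image_2", index 2 -> "input_image_3", ...
        return [f"input_image_{i + 1}" if i else "input_image" for i in range(n)]

    assert reference_image_keys(3) == ["input_image", "input_image_2", "input_image_3"]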
@@ -985,13 +637,11 @@ class BFLExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [
            FluxProUltraImageNode,
-            # FluxProImageNode,
            FluxKontextProImageNode,
            FluxKontextMaxImageNode,
            FluxProExpandNode,
            FluxProFillNode,
-            FluxProCannyNode,
-            FluxProDepthNode,
+            Flux2ProImageNode,
        ]

View File

@@ -1,35 +1,27 @@
 import logging
 import math
 from enum import Enum
-from typing import Literal, Optional, Type, Union
-from typing_extensions import override
+from typing import Literal, Optional, Union

 import torch
 from pydantic import BaseModel, Field
+from typing_extensions import override

-from comfy_api.latest import ComfyExtension, IO
-from comfy_api_nodes.util.validation_utils import (
-    validate_image_aspect_ratio_range,
-    get_number_of_images,
-    validate_image_dimensions,
-)
-from comfy_api_nodes.apis.client import (
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    EmptyRequest,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    T,
-)
-from comfy_api_nodes.apinode_utils import (
     download_url_to_image_tensor,
     download_url_to_video_output,
-    upload_images_to_comfyapi,
-    validate_string,
+    get_number_of_images,
     image_tensor_pair_to_batch,
+    poll_op,
+    sync_op,
+    upload_images_to_comfyapi,
+    validate_image_aspect_ratio,
+    validate_image_dimensions,
+    validate_string,
 )

 BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations"
 # Long-running tasks endpoints(e.g., video)
@@ -46,13 +38,14 @@ class Image2ImageModelName(str, Enum):
 class Text2VideoModelName(str, Enum):
     seedance_1_pro = "seedance-1-0-pro-250528"
     seedance_1_lite = "seedance-1-0-lite-t2v-250428"


 class Image2VideoModelName(str, Enum):
     """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
+
     seedance_1_pro = "seedance-1-0-pro-250528"
     seedance_1_lite = "seedance-1-0-lite-i2v-250428"
@@ -208,35 +201,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
     return None


-async def poll_until_finished(
-    auth_kwargs: dict[str, str],
-    task_id: str,
-    estimated_duration: Optional[int] = None,
-    node_id: Optional[str] = None,
-) -> TaskStatusResponse:
-    """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response."""
-    return await PollingOperation(
-        poll_endpoint=ApiEndpoint(
-            path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=TaskStatusResponse,
-        ),
-        completed_statuses=[
-            "succeeded",
-        ],
-        failed_statuses=[
-            "cancelled",
-            "failed",
-        ],
-        status_extractor=lambda response: response.status,
-        auth_kwargs=auth_kwargs,
-        result_url_extractor=get_video_url_from_task_status,
-        estimated_duration=estimated_duration,
-        node_id=node_id,
-    ).execute()
-
-
 class ByteDanceImageNode(IO.ComfyNode):
     @classmethod
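For orientation, the deleted poll_until_finished maps onto the shared poll_op helper used later in this file. A rough, hypothetical equivalent, assuming poll_op's defaults cover the succeeded/cancelled/failed statuses (the new call sites no longer pass them):

    # Hypothetical sketch only; the real call lives in process_video_task below.
    async def wait_for_task(cls, task_id: str, estimated_duration: int | None = None) -> TaskStatusResponse:
        return await poll_op(
            cls,
            ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}"),
            response_model=TaskStatusResponse,
            status_extractor=lambda r: r.status,
            estimated_duration=estimated_duration,  # auth and node id now come from `cls`
        )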
@@ -303,7 +267,7 @@ class ByteDanceImageNode(IO.ComfyNode):
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the image",
+                    tooltip='Whether to add an "AI generated" watermark to the image',
                    optional=True,
                ),
            ],
@@ -341,8 +305,7 @@ class ByteDanceImageNode(IO.ComfyNode):
            w, h = width, height
            if not (512 <= w <= 2048) or not (512 <= h <= 2048):
                raise ValueError(
-                    f"Custom size out of range: {w}x{h}. "
-                    "Both width and height must be between 512 and 2048 pixels."
+                    f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels."
                )

        payload = Text2ImageTaskCreationRequest(
@@ -353,20 +316,12 @@ class ByteDanceImageNode(IO.ComfyNode):
            guidance_scale=guidance_scale,
            watermark=watermark,
        )
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        response = await SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=BYTEPLUS_IMAGE_ENDPOINT,
-                method=HttpMethod.POST,
-                request_model=Text2ImageTaskCreationRequest,
-                response_model=ImageTaskCreationResponse,
-            ),
-            request=payload,
-            auth_kwargs=auth_kwargs,
-        ).execute()
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
+            data=payload,
+            response_model=ImageTaskCreationResponse,
+        )
        return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
@ -420,7 +375,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image", tooltip='Whether to add an "AI generated" watermark to the image',
optional=True, optional=True,
), ),
], ],
@@ -448,17 +403,8 @@ class ByteDanceImageEditNode(IO.ComfyNode):
        validate_string(prompt, strip_whitespace=True, min_length=1)
        if get_number_of_images(image) != 1:
            raise ValueError("Exactly one input image is required.")
-        validate_image_aspect_ratio_range(image, (1, 3), (3, 1))
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        source_url = (await upload_images_to_comfyapi(
-            image,
-            max_images=1,
-            mime_type="image/png",
-            auth_kwargs=auth_kwargs,
-        ))[0]
+        validate_image_aspect_ratio(image, (1, 3), (3, 1))
+        source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0]
        payload = Image2ImageTaskCreationRequest(
            model=model,
            prompt=prompt,
@@ -467,16 +413,12 @@ class ByteDanceImageEditNode(IO.ComfyNode):
            guidance_scale=guidance_scale,
            watermark=watermark,
        )
-        response = await SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=BYTEPLUS_IMAGE_ENDPOINT,
-                method=HttpMethod.POST,
-                request_model=Image2ImageTaskCreationRequest,
-                response_model=ImageTaskCreationResponse,
-            ),
-            request=payload,
-            auth_kwargs=auth_kwargs,
-        ).execute()
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
+            data=payload,
+            response_model=ImageTaskCreationResponse,
+        )
        return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
@ -504,7 +446,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Image.Input( IO.Image.Input(
"image", "image",
tooltip="Input image(s) for image-to-image generation. " tooltip="Input image(s) for image-to-image generation. "
"List of 1-10 images for single or multi-reference generation.", "List of 1-10 images for single or multi-reference generation.",
optional=True, optional=True,
), ),
IO.Combo.Input( IO.Combo.Input(
@ -534,9 +476,9 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
"sequential_image_generation", "sequential_image_generation",
options=["disabled", "auto"], options=["disabled", "auto"],
tooltip="Group image generation mode. " tooltip="Group image generation mode. "
"'disabled' generates a single image. " "'disabled' generates a single image. "
"'auto' lets the model decide whether to generate multiple related images " "'auto' lets the model decide whether to generate multiple related images "
"(e.g., story scenes, character variations).", "(e.g., story scenes, character variations).",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -547,7 +489,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
step=1, step=1,
display_mode=IO.NumberDisplay.number, display_mode=IO.NumberDisplay.number,
tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " tooltip="Maximum number of images to generate when sequential_image_generation='auto'. "
"Total images (input + generated) cannot exceed 15.", "Total images (input + generated) cannot exceed 15.",
optional=True, optional=True,
), ),
IO.Int.Input( IO.Int.Input(
@ -564,7 +506,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=True,
tooltip="Whether to add an \"AI generated\" watermark to the image.", tooltip='Whether to add an "AI generated" watermark to the image.',
optional=True, optional=True,
), ),
IO.Boolean.Input( IO.Boolean.Input(
@@ -611,8 +553,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
            w, h = width, height
            if not (1024 <= w <= 4096) or not (1024 <= h <= 4096):
                raise ValueError(
-                    f"Custom size out of range: {w}x{h}. "
-                    "Both width and height must be between 1024 and 4096 pixels."
+                    f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
                )
        n_input_images = get_number_of_images(image) if image is not None else 0
        if n_input_images > 10:
@@ -621,41 +562,31 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
            raise ValueError(
                "The maximum number of generated images plus the number of reference images cannot exceed 15."
            )
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
        reference_images_urls = []
        if n_input_images:
            for i in image:
-                validate_image_aspect_ratio_range(i, (1, 3), (3, 1))
-            reference_images_urls = (await upload_images_to_comfyapi(
+                validate_image_aspect_ratio(i, (1, 3), (3, 1))
+            reference_images_urls = await upload_images_to_comfyapi(
+                cls,
                image,
                max_images=n_input_images,
                mime_type="image/png",
-                auth_kwargs=auth_kwargs,
-            ))
-        payload = Seedream4TaskCreationRequest(
-            model=model,
-            prompt=prompt,
-            image=reference_images_urls,
-            size=f"{w}x{h}",
-            seed=seed,
-            sequential_image_generation=sequential_image_generation,
-            sequential_image_generation_options=Seedream4Options(max_images=max_images),
-            watermark=watermark,
-        )
-        response = await SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=BYTEPLUS_IMAGE_ENDPOINT,
-                method=HttpMethod.POST,
-                request_model=Seedream4TaskCreationRequest,
-                response_model=ImageTaskCreationResponse,
+            )
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"),
+            response_model=ImageTaskCreationResponse,
+            data=Seedream4TaskCreationRequest(
+                model=model,
+                prompt=prompt,
+                image=reference_images_urls,
+                size=f"{w}x{h}",
+                seed=seed,
+                sequential_image_generation=sequential_image_generation,
+                sequential_image_generation_options=Seedream4Options(max_images=max_images),
+                watermark=watermark,
            ),
-            request=payload,
-            auth_kwargs=auth_kwargs,
-        ).execute()
+        )
        if len(response.data) == 1:
            return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response)))
        urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d]
@@ -719,13 +650,13 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
                    "camera_fixed",
                    default=False,
                    tooltip="Specifies whether to fix the camera. The platform appends an instruction "
-                            "to fix the camera to your prompt, but does not guarantee the actual effect.",
+                    "to fix the camera to your prompt, but does not guarantee the actual effect.",
                    optional=True,
                ),
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the video.",
+                    tooltip='Whether to add an "AI generated" watermark to the video.',
                    optional=True,
                ),
            ],
@@ -764,19 +695,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
            f"--camerafixed {str(camera_fixed).lower()} "
            f"--watermark {str(watermark).lower()}"
        )
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
        return await process_video_task(
-            request_model=Text2VideoTaskCreationRequest,
-            payload=Text2VideoTaskCreationRequest(
-                model=model,
-                content=[TaskTextContent(text=prompt)],
-            ),
-            auth_kwargs=auth_kwargs,
-            node_id=cls.hidden.unique_id,
+            cls,
+            payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]),
            estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
        )
@@ -840,13 +761,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
                    "camera_fixed",
                    default=False,
                    tooltip="Specifies whether to fix the camera. The platform appends an instruction "
-                            "to fix the camera to your prompt, but does not guarantee the actual effect.",
+                    "to fix the camera to your prompt, but does not guarantee the actual effect.",
                    optional=True,
                ),
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the video.",
+                    tooltip='Whether to add an "AI generated" watermark to the video.',
                    optional=True,
                ),
            ],
@@ -877,15 +798,9 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
        validate_string(prompt, strip_whitespace=True, min_length=1)
        raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
        validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-        validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0]
-
+        validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+        image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
        prompt = (
            f"{prompt} "
            f"--resolution {resolution} "
@@ -897,13 +812,11 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
        )

        return await process_video_task(
-            request_model=Image2VideoTaskCreationRequest,
+            cls,
            payload=Image2VideoTaskCreationRequest(
                model=model,
                content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))],
            ),
-            auth_kwargs=auth_kwargs,
-            node_id=cls.hidden.unique_id,
            estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
        )
@@ -971,13 +884,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
                    "camera_fixed",
                    default=False,
                    tooltip="Specifies whether to fix the camera. The platform appends an instruction "
-                            "to fix the camera to your prompt, but does not guarantee the actual effect.",
+                    "to fix the camera to your prompt, but does not guarantee the actual effect.",
                    optional=True,
                ),
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the video.",
+                    tooltip='Whether to add an "AI generated" watermark to the video.',
                    optional=True,
                ),
            ],
@@ -1010,18 +923,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
        raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"])
        for i in (first_frame, last_frame):
            validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
-
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
+            validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5

        download_urls = await upload_images_to_comfyapi(
+            cls,
            image_tensor_pair_to_batch(first_frame, last_frame),
            max_images=2,
            mime_type="image/png",
-            auth_kwargs=auth_kwargs,
        )

        prompt = (
@@ -1035,7 +943,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
        )

        return await process_video_task(
-            request_model=Image2VideoTaskCreationRequest,
+            cls,
            payload=Image2VideoTaskCreationRequest(
                model=model,
                content=[
@@ -1044,8 +952,6 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
                    TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"),
                ],
            ),
-            auth_kwargs=auth_kwargs,
-            node_id=cls.hidden.unique_id,
            estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
        )
@ -1108,7 +1014,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
IO.Boolean.Input( IO.Boolean.Input(
"watermark", "watermark",
default=True, default=True,
tooltip="Whether to add an \"AI generated\" watermark to the video.", tooltip='Whether to add an "AI generated" watermark to the video.',
optional=True, optional=True,
), ),
], ],
@@ -1139,17 +1045,9 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
        raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"])
        for image in images:
            validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
-            validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
-
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        image_urls = await upload_images_to_comfyapi(
-            images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs
-        )
-
+            validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)  # 0.4 to 2.5
+        image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png")
        prompt = (
            f"{prompt} "
            f"--resolution {resolution} "
@@ -1160,42 +1058,32 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
        )
        x = [
            TaskTextContent(text=prompt),
-            *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls]
+            *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls],
        ]
        return await process_video_task(
-            request_model=Image2VideoTaskCreationRequest,
-            payload=Image2VideoTaskCreationRequest(
-                model=model,
-                content=x,
-            ),
-            auth_kwargs=auth_kwargs,
-            node_id=cls.hidden.unique_id,
+            cls,
+            payload=Image2VideoTaskCreationRequest(model=model, content=x),
            estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
        )


 async def process_video_task(
-    request_model: Type[T],
+    cls: type[IO.ComfyNode],
     payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
-    auth_kwargs: dict,
-    node_id: str,
     estimated_duration: Optional[int],
 ) -> IO.NodeOutput:
-    initial_response = await SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path=BYTEPLUS_TASK_ENDPOINT,
-            method=HttpMethod.POST,
-            request_model=request_model,
-            response_model=TaskCreationResponse,
-        ),
-        request=payload,
-        auth_kwargs=auth_kwargs,
-    ).execute()
-    response = await poll_until_finished(
-        auth_kwargs,
-        initial_response.id,
+    initial_response = await sync_op(
+        cls,
+        ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
+        data=payload,
+        response_model=TaskCreationResponse,
+    )
+    response = await poll_op(
+        cls,
+        ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"),
+        status_extractor=lambda r: r.status,
        estimated_duration=estimated_duration,
-        node_id=node_id,
+        response_model=TaskStatusResponse,
    )
    return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
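Call sites now pass the node class itself as the first argument; the auth and progress routing that previously traveled through auth_kwargs and node_id is derived from it. A representative, hypothetical invocation mirroring the nodes above:

    # Hypothetical wrapper mirroring ByteDanceTextToVideoNode.execute above.
    async def submit_example(cls, model: str, prompt: str) -> IO.NodeOutput:
        return await process_video_task(
            cls,
            payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]),
            estimated_duration=60,  # seconds; illustrative value
        )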
@@ -1221,5 +1109,6 @@ class ByteDanceExtension(ComfyExtension):
            ByteDanceImageReferenceNode,
        ]

+
 async def comfy_entrypoint() -> ByteDanceExtension:
     return ByteDanceExtension()

File diff suppressed because it is too large

View File

@@ -1,6 +1,6 @@
 from io import BytesIO
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
+from comfy_api.latest import IO, ComfyExtension
 from PIL import Image
 import numpy as np
 import torch
@@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
    IdeogramV3Request,
    IdeogramV3EditRequest,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-)
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_bytesio,
    bytesio_to_image_tensor,
+    download_url_as_bytesio,
    resize_mask_to_image,
+    sync_op,
 )
-from server import PromptServer
V1_V1_RES_MAP = { V1_V1_RES_MAP = {
"Auto":"AUTO", "Auto":"AUTO",
@@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
    for image_url in image_urls:
        # Using functions from apinode_utils.py to handle downloading and processing
-        image_bytesio = await download_url_to_bytesio(image_url)  # Download image content to BytesIO
+        image_bytesio = await download_url_as_bytesio(image_url)  # Download image content to BytesIO
        img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB")  # Convert to torch.Tensor with RGB mode
        image_tensors.append(img_tensor)
@@ -233,19 +227,6 @@ async def download_and_process_images(image_urls):
    return stacked_tensors


-def display_image_urls_on_node(image_urls, node_id):
-    if node_id and image_urls:
-        if len(image_urls) == 1:
-            PromptServer.instance.send_progress_text(
-                f"Generated Image URL:\n{image_urls[0]}", node_id
-            )
-        else:
-            urls_text = "Generated Image URLs:\n" + "\n".join(
-                f"{i+1}. {url}" for i, url in enumerate(image_urls)
-            )
-            PromptServer.instance.send_progress_text(urls_text, node_id)
-
-
 class IdeogramV1(IO.ComfyNode):
     @classmethod
@@ -334,44 +315,30 @@ class IdeogramV1(IO.ComfyNode):
        aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
        model = "V_1_TURBO" if turbo else "V_1"

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
-                response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
+            response_model=IdeogramGenerateResponse,
+            data=IdeogramGenerateRequest(
                image_request=ImageRequest(
                    prompt=prompt,
                    model=model,
                    num_images=num_images,
                    seed=seed,
                    aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                    negative_prompt=negative_prompt if negative_prompt else None,
                )
            ),
-            auth_kwargs=auth,
+            max_retries=1,
        )
-        response = await operation.execute()

        if not response.data or len(response.data) == 0:
            raise Exception("No images were generated in the response")

        image_urls = [image_data.url for image_data in response.data if image_data.url]

        if not image_urls:
            raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
        return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -500,18 +467,11 @@ class IdeogramV2(IO.ComfyNode):
        else:
            final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/ideogram/generate",
-                method=HttpMethod.POST,
-                request_model=IdeogramGenerateRequest,
-                response_model=IdeogramGenerateResponse,
-            ),
-            request=IdeogramGenerateRequest(
+        response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
+            response_model=IdeogramGenerateResponse,
+            data=IdeogramGenerateRequest(
                image_request=ImageRequest(
                    prompt=prompt,
                    model=model,
@@ -519,28 +479,20 @@ class IdeogramV2(IO.ComfyNode):
                    seed=seed,
                    aspect_ratio=final_aspect_ratio,
                    resolution=final_resolution,
-                    magic_prompt_option=(
-                        magic_prompt_option if magic_prompt_option != "AUTO" else None
-                    ),
+                    magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
                    style_type=style_type if style_type != "NONE" else None,
                    negative_prompt=negative_prompt if negative_prompt else None,
                    color_palette=color_palette if color_palette else None,
                )
            ),
-            auth_kwargs=auth,
+            max_retries=1,
        )
-        response = await operation.execute()

        if not response.data or len(response.data) == 0:
            raise Exception("No images were generated in the response")

        image_urls = [image_data.url for image_data in response.data if image_data.url]

        if not image_urls:
            raise Exception("No image URLs were generated in the response")

-        display_image_urls_on_node(image_urls, cls.hidden.unique_id)
        return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -656,10 +608,6 @@ class IdeogramV3(IO.ComfyNode):
        character_image=None,
        character_mask=None,
    ):
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
        if rendering_speed == "BALANCED":  # for backward compatibility
            rendering_speed = "DEFAULT"

@@ -694,9 +642,6 @@ class IdeogramV3(IO.ComfyNode):
        # Check if both image and mask are provided for editing mode
        if image is not None and mask is not None:
-            # Edit mode
-            path = "/proxy/ideogram/ideogram-v3/edit"
-
            # Process image and mask
            input_tensor = image.squeeze().cpu()
            # Resize mask to match image dimension
@@ -749,27 +694,20 @@ class IdeogramV3(IO.ComfyNode):
            if character_mask_binary:
                files["character_mask_binary"] = character_mask_binary

-            # Execute the operation for edit mode
-            operation = SynchronousOperation(
-                endpoint=ApiEndpoint(
-                    path=path,
-                    method=HttpMethod.POST,
-                    request_model=IdeogramV3EditRequest,
-                    response_model=IdeogramGenerateResponse,
-                ),
-                request=edit_request,
+            response = await sync_op(
+                cls,
+                ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
+                response_model=IdeogramGenerateResponse,
+                data=edit_request,
                files=files,
                content_type="multipart/form-data",
-                auth_kwargs=auth,
+                max_retries=1,
            )

        elif image is not None or mask is not None:
            # If only one of image or mask is provided, raise an error
            raise Exception("Ideogram V3 image editing requires both an image AND a mask")
        else:
-            # Generation mode
-            path = "/proxy/ideogram/ideogram-v3/generate"
-
            # Create generation request
            gen_request = IdeogramV3Request(
                prompt=prompt,
gen_request = IdeogramV3Request( gen_request = IdeogramV3Request(
prompt=prompt, prompt=prompt,
@ -800,32 +738,22 @@ class IdeogramV3(IO.ComfyNode):
if files: if files:
gen_request.style_type = "AUTO" gen_request.style_type = "AUTO"
# Execute the operation for generation mode response = await sync_op(
operation = SynchronousOperation( cls,
endpoint=ApiEndpoint( endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
path=path, response_model=IdeogramGenerateResponse,
method=HttpMethod.POST, data=gen_request,
request_model=IdeogramV3Request,
response_model=IdeogramGenerateResponse,
),
request=gen_request,
files=files if files else None, files=files if files else None,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth, max_retries=1,
) )
# Execute the operation and process response
response = await operation.execute()
if not response.data or len(response.data) == 0: if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response") raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url] image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls: if not image_urls:
raise Exception("No image URLs were generated in the response") raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
return IO.NodeOutput(await download_and_process_images(image_urls)) return IO.NodeOutput(await download_and_process_images(image_urls))
@@ -838,5 +766,6 @@ class IdeogramExtension(ComfyExtension):
            IdeogramV3,
        ]

+
 async def comfy_entrypoint() -> IdeogramExtension:
     return IdeogramExtension()

View File

@@ -5,8 +5,7 @@ For source of truth on the allowed permutations of request fields, please refere
 """
 from __future__ import annotations

-from typing import Optional, TypeVar, Any
-from collections.abc import Callable
+from typing import Optional, TypeVar

 import math
 import logging
@@ -15,7 +14,6 @@ from typing_extensions import override
 import torch

 from comfy_api_nodes.apis import (
-    KlingTaskStatus,
    KlingCameraControl,
    KlingCameraConfig,
    KlingCameraControlType,
@@ -52,26 +50,20 @@ from comfy_api_nodes.apis import (
    KlingCharacterEffectModelName,
    KlingSingleImageEffectModelName,
 )
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
-    tensor_to_base64_string,
-    download_url_to_video_output,
-    upload_video_to_comfyapi,
-    upload_audio_to_comfyapi,
-    download_url_to_image_tensor,
-    validate_string,
-)
-from comfy_api_nodes.util.validation_utils import (
+from comfy_api_nodes.util import (
    validate_image_dimensions,
    validate_image_aspect_ratio,
    validate_video_dimensions,
    validate_video_duration,
+    tensor_to_base64_string,
+    validate_string,
+    upload_audio_to_comfyapi,
+    download_url_to_image_tensor,
+    upload_video_to_comfyapi,
+    download_url_to_video_output,
+    sync_op,
+    ApiEndpoint,
+    poll_op,
 )
 from comfy_api.input_impl import VideoFromFile
 from comfy_api.input.basic_types import AudioInput
@@ -214,34 +206,6 @@ VOICES_CONFIG = {
 }


-async def poll_until_finished(
-    auth_kwargs: dict[str, str],
-    api_endpoint: ApiEndpoint[Any, R],
-    result_url_extractor: Optional[Callable[[R], str]] = None,
-    estimated_duration: Optional[int] = None,
-    node_id: Optional[str] = None,
-) -> R:
-    """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
-    return await PollingOperation(
-        poll_endpoint=api_endpoint,
-        completed_statuses=[
-            KlingTaskStatus.succeed.value,
-        ],
-        failed_statuses=[KlingTaskStatus.failed.value],
-        status_extractor=lambda response: (
-            response.data.task_status.value
-            if response.data and response.data.task_status
-            else None
-        ),
-        auth_kwargs=auth_kwargs,
-        result_url_extractor=result_url_extractor,
-        estimated_duration=estimated_duration,
-        node_id=node_id,
-        poll_interval=16.0,
-        max_poll_attempts=256,
-    ).execute()
-
-
 def is_valid_camera_control_configs(configs: list[float]) -> bool:
     """Verifies that at least one camera control configuration is non-zero."""
     return any(not math.isclose(value, 0.0) for value in configs)
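The guarded status lookup from the deleted helper now recurs inline as a lambda at each poll_op call site below. Its equivalent standalone form, for reference:

    def kling_task_status(response) -> str | None:
        # Kling nests task_status under `data`; either may be absent while a task is still queuing.
        return response.data.task_status.value if response.data and response.data.task_status else None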
@@ -318,7 +282,7 @@ def validate_input_image(image: torch.Tensor) -> None:
    See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
    """
    validate_image_dimensions(image, min_width=300, min_height=300)
-    validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
+    validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1))


 def get_video_from_response(response) -> KlingVideoResult:
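The updated validate_image_aspect_ratio takes ratio pairs instead of floats: (1, 2.5) is 1:2.5 and (2.5, 1) is 2.5:1, so the accepted range is unchanged. A worked check, assuming inclusive bounds:

    def within_kling_bounds(width: int, height: int) -> bool:
        # (1, 2.5) -> minimum ratio 0.4; (2.5, 1) -> maximum ratio 2.5
        ratio = width / height
        return (1 / 2.5) <= ratio <= (2.5 / 1)

    assert within_kling_bounds(400, 1000) and not within_kling_bounds(399, 1000)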
@@ -377,8 +341,7 @@ async def image_result_to_node_output(
 async def execute_text2video(
-    auth_kwargs: dict[str, str],
-    node_id: str,
+    cls: type[IO.ComfyNode],
    prompt: str,
    negative_prompt: str,
    cfg_scale: float,
@@ -389,14 +352,11 @@ async def execute_text2video(
    camera_control: Optional[KlingCameraControl] = None,
 ) -> IO.NodeOutput:
    validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
-    initial_operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path=PATH_TEXT_TO_VIDEO,
-            method=HttpMethod.POST,
-            request_model=KlingText2VideoRequest,
-            response_model=KlingText2VideoResponse,
-        ),
-        request=KlingText2VideoRequest(
+    task_creation_response = await sync_op(
+        cls,
+        ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
+        response_model=KlingText2VideoResponse,
+        data=KlingText2VideoRequest(
            prompt=prompt if prompt else None,
            negative_prompt=negative_prompt if negative_prompt else None,
            duration=KlingVideoGenDuration(duration),
@@ -406,24 +366,17 @@ async def execute_text2video(
            aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
            camera_control=camera_control,
        ),
-        auth_kwargs=auth_kwargs,
    )
-    task_creation_response = await initial_operation.execute()
    validate_task_creation_response(task_creation_response)
    task_id = task_creation_response.data.task_id

-    final_response = await poll_until_finished(
-        auth_kwargs,
-        ApiEndpoint(
-            path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=KlingText2VideoResponse,
-        ),
-        result_url_extractor=get_video_url_from_response,
+    final_response = await poll_op(
+        cls,
+        ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"),
+        response_model=KlingText2VideoResponse,
        estimated_duration=AVERAGE_DURATION_T2V,
-        node_id=node_id,
+        status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
    )

    validate_video_result_response(final_response)
async def execute_image2video( async def execute_image2video(
auth_kwargs: dict[str, str], cls: type[IO.ComfyNode],
node_id: str,
start_frame: torch.Tensor, start_frame: torch.Tensor,
prompt: str, prompt: str,
negative_prompt: str, negative_prompt: str,
@ -455,14 +407,11 @@ async def execute_image2video(
if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value:
model_mode = "pro" # October 5: currently "std" mode is not supported for this model model_mode = "pro" # October 5: currently "std" mode is not supported for this model
initial_operation = SynchronousOperation( task_creation_response = await sync_op(
endpoint=ApiEndpoint( cls,
path=PATH_IMAGE_TO_VIDEO, ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
method=HttpMethod.POST, response_model=KlingImage2VideoResponse,
request_model=KlingImage2VideoRequest, data=KlingImage2VideoRequest(
response_model=KlingImage2VideoResponse,
),
request=KlingImage2VideoRequest(
model_name=KlingVideoGenModelName(model_name), model_name=KlingVideoGenModelName(model_name),
image=tensor_to_base64_string(start_frame), image=tensor_to_base64_string(start_frame),
image_tail=( image_tail=(
@ -477,24 +426,17 @@ async def execute_image2video(
duration=KlingVideoGenDuration(duration), duration=KlingVideoGenDuration(duration),
camera_control=camera_control, camera_control=camera_control,
), ),
auth_kwargs=auth_kwargs,
) )
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_until_finished( final_response = await poll_op(
auth_kwargs, cls,
ApiEndpoint( ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", response_model=KlingImage2VideoResponse,
method=HttpMethod.GET,
request_model=KlingImage2VideoRequest,
response_model=KlingImage2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V, estimated_duration=AVERAGE_DURATION_I2V,
node_id=node_id, status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_video_result_response(final_response) validate_video_result_response(final_response)
@ -503,8 +445,7 @@ async def execute_image2video(
async def execute_video_effect( async def execute_video_effect(
auth_kwargs: dict[str, str], cls: type[IO.ComfyNode],
node_id: str,
dual_character: bool, dual_character: bool,
effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
model_name: str, model_name: str,
@ -530,35 +471,25 @@ async def execute_video_effect(
duration=duration, duration=duration,
) )
initial_operation = SynchronousOperation( task_creation_response = await sync_op(
endpoint=ApiEndpoint( cls,
path=PATH_VIDEO_EFFECTS, endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"),
method=HttpMethod.POST, response_model=KlingVideoEffectsResponse,
request_model=KlingVideoEffectsRequest, data=KlingVideoEffectsRequest(
response_model=KlingVideoEffectsResponse,
),
request=KlingVideoEffectsRequest(
effect_scene=effect_scene, effect_scene=effect_scene,
input=request_input_field, input=request_input_field,
), ),
auth_kwargs=auth_kwargs,
) )
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_until_finished( final_response = await poll_op(
auth_kwargs, cls,
ApiEndpoint( ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"),
path=f"{PATH_VIDEO_EFFECTS}/{task_id}", response_model=KlingVideoEffectsResponse,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVideoEffectsResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
node_id=node_id, status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_video_result_response(final_response) validate_video_result_response(final_response)
@ -567,8 +498,7 @@ async def execute_video_effect(
async def execute_lipsync( async def execute_lipsync(
auth_kwargs: dict[str, str], cls: type[IO.ComfyNode],
node_id: str,
video: VideoInput, video: VideoInput,
audio: Optional[AudioInput] = None, audio: Optional[AudioInput] = None,
voice_language: Optional[str] = None, voice_language: Optional[str] = None,
@ -583,24 +513,23 @@ async def execute_lipsync(
validate_video_duration(video, 2, 10) validate_video_duration(video, 2, 10)
# Upload video to Comfy API and get download URL # Upload video to Comfy API and get download URL
video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs) video_url = await upload_video_to_comfyapi(cls, video)
logging.info("Uploaded video to Comfy API. URL: %s", video_url) logging.info("Uploaded video to Comfy API. URL: %s", video_url)
# Upload the audio file to Comfy API and get download URL # Upload the audio file to Comfy API and get download URL
if audio: if audio:
audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs) audio_url = await upload_audio_to_comfyapi(
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
)
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
else: else:
audio_url = None audio_url = None
initial_operation = SynchronousOperation( task_creation_response = await sync_op(
endpoint=ApiEndpoint( cls,
path=PATH_LIP_SYNC, ApiEndpoint(PATH_LIP_SYNC, "POST"),
method=HttpMethod.POST, response_model=KlingLipSyncResponse,
request_model=KlingLipSyncRequest, data=KlingLipSyncRequest(
response_model=KlingLipSyncResponse,
),
request=KlingLipSyncRequest(
input=KlingLipSyncInputObject( input=KlingLipSyncInputObject(
video_url=video_url, video_url=video_url,
mode=model_mode, mode=model_mode,
@ -612,24 +541,17 @@ async def execute_lipsync(
voice_id=voice_id, voice_id=voice_id,
), ),
), ),
auth_kwargs=auth_kwargs,
) )
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_until_finished( final_response = await poll_op(
auth_kwargs, cls,
ApiEndpoint( ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"),
path=f"{PATH_LIP_SYNC}/{task_id}", response_model=KlingLipSyncResponse,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingLipSyncResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_LIP_SYNC, estimated_duration=AVERAGE_DURATION_LIP_SYNC,
node_id=node_id, status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_video_result_response(final_response) validate_video_result_response(final_response)
@@ -807,11 +729,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
    ) -> IO.NodeOutput:
        model_mode, duration, model_name = MODE_TEXT2VIDEO[mode]
        return await execute_text2video(
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
            prompt=prompt,
            negative_prompt=negative_prompt,
            cfg_scale=cfg_scale,
@@ -872,11 +790,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
        camera_control: Optional[KlingCameraControl] = None,
    ) -> IO.NodeOutput:
        return await execute_text2video(
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
            model_name=KlingVideoGenModelName.kling_v1,
            cfg_scale=cfg_scale,
            model_mode=KlingVideoGenMode.std,
@@ -944,11 +858,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
        end_frame: Optional[torch.Tensor] = None,
    ) -> IO.NodeOutput:
        return await execute_image2video(
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
            start_frame=start_frame,
            prompt=prompt,
            negative_prompt=negative_prompt,
@@ -1017,11 +927,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
        camera_control: KlingCameraControl,
    ) -> IO.NodeOutput:
        return await execute_image2video(
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
            model_name=KlingVideoGenModelName.kling_v1_5,
            start_frame=start_frame,
            cfg_scale=cfg_scale,
@@ -1097,11 +1003,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
    ) -> IO.NodeOutput:
        mode, duration, model_name = MODE_START_END_FRAME[mode]
        return await execute_image2video(
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
            prompt=prompt,
            negative_prompt=negative_prompt,
            model_name=model_name,
@@ -1162,41 +1064,27 @@ class KlingVideoExtendNode(IO.ComfyNode):
        video_id: str,
    ) -> IO.NodeOutput:
        validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_VIDEO_EXTEND,
-                method=HttpMethod.POST,
-                request_model=KlingVideoExtendRequest,
-                response_model=KlingVideoExtendResponse,
-            ),
-            request=KlingVideoExtendRequest(
+        task_creation_response = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"),
+            response_model=KlingVideoExtendResponse,
+            data=KlingVideoExtendRequest(
                prompt=prompt if prompt else None,
                negative_prompt=negative_prompt if negative_prompt else None,
                cfg_scale=cfg_scale,
                video_id=video_id,
            ),
-            auth_kwargs=auth,
        )
-        task_creation_response = await initial_operation.execute()
        validate_task_creation_response(task_creation_response)
        task_id = task_creation_response.data.task_id

-        final_response = await poll_until_finished(
-            auth,
-            ApiEndpoint(
-                path=f"{PATH_VIDEO_EXTEND}/{task_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=KlingVideoExtendResponse,
-            ),
-            result_url_extractor=get_video_url_from_response,
+        final_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"),
+            response_model=KlingVideoExtendResponse,
            estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
-            node_id=cls.hidden.unique_id,
+            status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
        )

        validate_video_result_response(final_response)
@@ -1259,11 +1147,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
        duration: KlingVideoGenDuration,
    ) -> IO.NodeOutput:
        video, _, duration = await execute_video_effect(
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
            dual_character=True,
            effect_scene=effect_scene,
            model_name=model_name,
@@ -1324,11 +1208,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
        return IO.NodeOutput(
            *(
                await execute_video_effect(
-                    auth_kwargs={
-                        "auth_token": cls.hidden.auth_token_comfy_org,
-                        "comfy_api_key": cls.hidden.api_key_comfy_org,
+                    cls,
},
node_id=cls.hidden.unique_id,
dual_character=False, dual_character=False,
effect_scene=effect_scene, effect_scene=effect_scene,
model_name=model_name, model_name=model_name,
@ -1379,11 +1259,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
voice_language: str, voice_language: str,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
return await execute_lipsync( return await execute_lipsync(
auth_kwargs={ cls,
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
video=video, video=video,
audio=audio, audio=audio,
voice_language=voice_language, voice_language=voice_language,
@ -1445,11 +1321,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
voice_id, voice_language = VOICES_CONFIG[voice] voice_id, voice_language = VOICES_CONFIG[voice]
return await execute_lipsync( return await execute_lipsync(
auth_kwargs={ cls,
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
node_id=cls.hidden.unique_id,
video=video, video=video,
text=text, text=text,
voice_language=voice_language, voice_language=voice_language,
@ -1496,40 +1368,26 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
cloth_image: torch.Tensor, cloth_image: torch.Tensor,
model_name: KlingVirtualTryOnModelName, model_name: KlingVirtualTryOnModelName,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
auth = { task_creation_response = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"),
} response_model=KlingVirtualTryOnResponse,
initial_operation = SynchronousOperation( data=KlingVirtualTryOnRequest(
endpoint=ApiEndpoint(
path=PATH_VIRTUAL_TRY_ON,
method=HttpMethod.POST,
request_model=KlingVirtualTryOnRequest,
response_model=KlingVirtualTryOnResponse,
),
request=KlingVirtualTryOnRequest(
human_image=tensor_to_base64_string(human_image), human_image=tensor_to_base64_string(human_image),
cloth_image=tensor_to_base64_string(cloth_image), cloth_image=tensor_to_base64_string(cloth_image),
model_name=model_name, model_name=model_name,
), ),
auth_kwargs=auth,
) )
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_until_finished( final_response = await poll_op(
auth, cls,
ApiEndpoint( ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"),
path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", response_model=KlingVirtualTryOnResponse,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingVirtualTryOnResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
node_id=cls.hidden.unique_id, status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_image_result_response(final_response) validate_image_result_response(final_response)
@ -1625,18 +1483,11 @@ class KlingImageGenerationNode(IO.ComfyNode):
else: else:
image = tensor_to_base64_string(image) image = tensor_to_base64_string(image)
auth = { task_creation_response = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
} response_model=KlingImageGenerationsResponse,
initial_operation = SynchronousOperation( data=KlingImageGenerationsRequest(
endpoint=ApiEndpoint(
path=PATH_IMAGE_GENERATIONS,
method=HttpMethod.POST,
request_model=KlingImageGenerationsRequest,
response_model=KlingImageGenerationsResponse,
),
request=KlingImageGenerationsRequest(
model_name=model_name, model_name=model_name,
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
@ -1647,24 +1498,17 @@ class KlingImageGenerationNode(IO.ComfyNode):
n=n, n=n,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
), ),
auth_kwargs=auth,
) )
task_creation_response = await initial_operation.execute()
validate_task_creation_response(task_creation_response) validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id task_id = task_creation_response.data.task_id
final_response = await poll_until_finished( final_response = await poll_op(
auth, cls,
ApiEndpoint( ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"),
path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", response_model=KlingImageGenerationsResponse,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=KlingImageGenerationsResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_IMAGE_GEN, estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
node_id=cls.hidden.unique_id, status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
) )
validate_image_result_response(final_response) validate_image_result_response(final_response)
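Every Kling node above follows the same collapsed pattern: `sync_op` replaces the manual auth dict plus `SynchronousOperation(...).execute()`, and `poll_op` replaces `poll_until_finished`/`PollingOperation`, with auth and the progress-reporting node id now derived from the node class passed as the first argument. A minimal sketch of that calling convention, using placeholder models and a placeholder path (nothing below names a real endpoint):

from typing import Optional

from pydantic import BaseModel

# Placeholder request/response models and path: illustration only, not a real
# Comfy endpoint or schema.
class ExampleRequest(BaseModel):
    prompt: str

class ExampleResponse(BaseModel):
    task_id: Optional[str] = None
    status: Optional[str] = None

async def run_example_task(cls, prompt: str) -> ExampleResponse:
    # sync_op POSTs the request body and parses the reply into response_model;
    # auth and the progress-reporting node id come from the node class itself.
    task = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/example/tasks", method="POST"),
        response_model=ExampleResponse,
        data=ExampleRequest(prompt=prompt),
    )
    # poll_op GETs the status endpoint (GET appears to be the default, since
    # the refactored nodes omit the method) until status_extractor reports a
    # terminal state, then returns the final parsed response.
    return await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/example/tasks/{task.task_id}"),
        response_model=ExampleResponse,
        status_extractor=lambda r: r.status,
        estimated_duration=60,
    )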

View File

@ -0,0 +1,199 @@
from io import BytesIO
from typing import Optional
import torch
from pydantic import BaseModel, Field
from typing_extensions import override
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
sync_op_raw,
upload_images_to_comfyapi,
validate_string,
)
MODELS_MAP = {
"LTX-2 (Pro)": "ltx-2-pro",
"LTX-2 (Fast)": "ltx-2-fast",
}
class ExecuteTaskRequest(BaseModel):
prompt: str = Field(...)
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
class TextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
data=ExecuteTaskRequest(
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class ImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="api node/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
IO.String.Input(
"prompt",
multiline=True,
default="",
),
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input(
"resolution",
options=[
"1920x1080",
"2560x1440",
"3840x2160",
],
),
IO.Combo.Input("fps", options=[25, 50], default=25),
IO.Boolean.Input(
"generate_audio",
default=False,
optional=True,
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
image: torch.Tensor,
model: str,
prompt: str,
duration: int,
resolution: str,
fps: int = 25,
generate_audio: bool = False,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw(
cls,
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
data=ExecuteTaskRequest(
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
prompt=prompt,
model=MODELS_MAP[model],
duration=duration,
resolution=resolution,
fps=fps,
generate_audio=generate_audio,
),
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
class LtxvApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TextToVideoNode,
ImageToVideoNode,
]
async def comfy_entrypoint() -> LtxvApiExtension:
return LtxvApiExtension()
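Unlike the task-then-poll providers elsewhere in this commit, the LTXV proxy returns the finished video bytes in the HTTP response itself, which is why both nodes call `sync_op_raw` with `as_binary=True` and wrap the result in `VideoFromFile(BytesIO(response))`. The duration gate both `execute` methods apply can be restated as a small standalone check; this is a sketch mirroring the guard as written above:

def check_ltx_duration(model: str, duration: int, resolution: str, fps: int) -> None:
    # Mirror of the guard in both LTXV nodes: clips longer than 10s are only
    # allowed for the Fast model at 1920x1080 and 25 FPS.
    if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
        raise ValueError(
            "Durations over 10s are only available for the Fast model "
            "at 1920x1080 resolution and 25 FPS."
        )

# e.g. check_ltx_duration("LTX-2 (Pro)", 12, "1920x1080", 25) raises ValueError,
# while check_ltx_duration("LTX-2 (Fast)", 12, "1920x1080", 25) passes.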

View File

@ -1,69 +1,51 @@
-from __future__ import annotations
-from inspect import cleandoc
 from typing import Optional
+
+import torch
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
-from comfy_api.input_impl.video_types import VideoFromFile
+
+from comfy_api.latest import IO, ComfyExtension
 from comfy_api_nodes.apis.luma_api import (
-    LumaImageModel,
-    LumaVideoModel,
-    LumaVideoOutputResolution,
-    LumaVideoModelOutputDuration,
     LumaAspectRatio,
-    LumaState,
-    LumaImageGenerationRequest,
-    LumaGenerationRequest,
-    LumaGeneration,
     LumaCharacterRef,
-    LumaModifyImageRef,
+    LumaConceptChain,
+    LumaGeneration,
+    LumaGenerationRequest,
+    LumaImageGenerationRequest,
     LumaImageIdentity,
+    LumaImageModel,
+    LumaImageReference,
+    LumaIO,
+    LumaKeyframes,
+    LumaModifyImageRef,
     LumaReference,
     LumaReferenceChain,
-    LumaImageReference,
-    LumaKeyframes,
-    LumaConceptChain,
-    LumaIO,
+    LumaVideoModel,
+    LumaVideoModelOutputDuration,
+    LumaVideoOutputResolution,
     get_luma_concepts,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
+    download_url_to_image_tensor,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
     upload_images_to_comfyapi,
-    process_image_response,
     validate_string,
 )
-from server import PromptServer
-import aiohttp
-import torch
-from io import BytesIO

 LUMA_T2V_AVERAGE_DURATION = 105
 LUMA_I2V_AVERAGE_DURATION = 100

-def image_result_url_extractor(response: LumaGeneration):
-    return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
-
-def video_result_url_extractor(response: LumaGeneration):
-    return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None

 class LumaReferenceNode(IO.ComfyNode):
-    """
-    Holds an image and weight for use with Luma Generate Image node.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaReferenceNode",
             display_name="Luma Reference",
             category="api node/image/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Holds an image and weight for use with Luma Generate Image node.",
             inputs=[
                 IO.Image.Input(
                     "image",
@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode):
                 ),
             ],
             outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
         )

     @classmethod
-    def execute(
-        cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
-    ) -> IO.NodeOutput:
+    def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
         if luma_ref is not None:
             luma_ref = luma_ref.clone()
         else:
@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode):

 class LumaConceptsNode(IO.ComfyNode):
-    """
-    Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaConceptsNode",
             display_name="Luma Concepts",
             category="api node/video/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
             inputs=[
                 IO.Combo.Input(
                     "concept1",
@ -138,11 +109,6 @@ class LumaConceptsNode(IO.ComfyNode):
                 ),
             ],
             outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
-            hidden=[
-                IO.Hidden.auth_token_comfy_org,
-                IO.Hidden.api_key_comfy_org,
-                IO.Hidden.unique_id,
-            ],
         )

     @classmethod
@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode):

 class LumaImageGenerationNode(IO.ComfyNode):
-    """
-    Generates images synchronously based on prompt and aspect ratio.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaImageNode",
             display_name="Luma Text to Image",
             category="api node/image/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates images synchronously based on prompt and aspect ratio.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@ -237,45 +199,30 @@ class LumaImageGenerationNode(IO.ComfyNode):
         aspect_ratio: str,
         seed,
         style_image_weight: float,
-        image_luma_ref: LumaReferenceChain = None,
-        style_image: torch.Tensor = None,
-        character_image: torch.Tensor = None,
+        image_luma_ref: Optional[LumaReferenceChain] = None,
+        style_image: Optional[torch.Tensor] = None,
+        character_image: Optional[torch.Tensor] = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=True, min_length=3)
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         # handle image_luma_ref
         api_image_ref = None
         if image_luma_ref is not None:
-            api_image_ref = await cls._convert_luma_refs(
-                image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
-            )
+            api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
         # handle style_luma_ref
         api_style_ref = None
         if style_image is not None:
-            api_style_ref = await cls._convert_style_image(
-                style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
-            )
+            api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
         # handle character_ref images
         character_ref = None
         if character_image is not None:
-            download_urls = await upload_images_to_comfyapi(
-                character_image, max_images=4, auth_kwargs=auth_kwargs,
-            )
-            character_ref = LumaCharacterRef(
-                identity0=LumaImageIdentity(images=download_urls)
-            )
+            download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
+            character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))

-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations/image",
-                method=HttpMethod.POST,
-                request_model=LumaImageGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaImageGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaImageGenerationRequest(
                 prompt=prompt,
                 model=model,
                 aspect_ratio=aspect_ratio,
@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode):
                 style_ref=api_style_ref,
                 character_ref=character_ref,
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=image_result_url_extractor,
-            node_id=cls.hidden.unique_id,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.image) as img_response:
-                img = process_image_response(await img_response.content.read())
-        return IO.NodeOutput(img)
+        return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))

     @classmethod
-    async def _convert_luma_refs(
-        cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
-    ):
+    async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
         luma_urls = []
         ref_count = 0
         for ref in luma_ref.refs:
-            download_urls = await upload_images_to_comfyapi(
-                ref.image, max_images=1, auth_kwargs=auth_kwargs
-            )
+            download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
             luma_urls.append(download_urls[0])
             ref_count += 1
             if ref_count >= max_refs:
@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode):
         return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)

     @classmethod
-    async def _convert_style_image(
-        cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
-    ):
-        chain = LumaReferenceChain(
-            first_ref=LumaReference(image=style_image, weight=weight)
-        )
-        return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
+    async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
+        chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
+        return await cls._convert_luma_refs(chain, max_refs=1)

 class LumaImageModifyNode(IO.ComfyNode):
-    """
-    Modifies images synchronously based on prompt and aspect ratio.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaImageModifyNode",
             display_name="Luma Image to Image",
             category="api node/image/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Modifies images synchronously based on prompt and aspect ratio.",
             inputs=[
                 IO.Image.Input(
                     "image",
@ -395,68 +314,37 @@ class LumaImageModifyNode(IO.ComfyNode):
         image_weight: float,
         seed,
     ) -> IO.NodeOutput:
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        # first, upload image
-        download_urls = await upload_images_to_comfyapi(
-            image, max_images=1, auth_kwargs=auth_kwargs,
-        )
+        download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
         image_url = download_urls[0]
-        # next, make Luma call with download url provided
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations/image",
-                method=HttpMethod.POST,
-                request_model=LumaImageGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaImageGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaImageGenerationRequest(
                 prompt=prompt,
                 model=model,
                 modify_image_ref=LumaModifyImageRef(
-                    url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
+                    url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
                 ),
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=image_result_url_extractor,
-            node_id=cls.hidden.unique_id,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.image) as img_response:
-                img = process_image_response(await img_response.content.read())
-        return IO.NodeOutput(img)
+        return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))

 class LumaTextToVideoGenerationNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on prompt and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaVideoNode",
             display_name="Luma Text to Video",
             category="api node/video/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on prompt and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
                     "luma_concepts",
                     tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
                     optional=True,
-                )
+                ),
             ],
             outputs=[IO.Video.Output()],
             hidden=[
@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
         duration: str,
         loop: bool,
         seed,
-        luma_concepts: LumaConceptChain = None,
+        luma_concepts: Optional[LumaConceptChain] = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False, min_length=3)
         duration = duration if model != LumaVideoModel.ray_1_6 else None
         resolution = resolution if model != LumaVideoModel.ray_1_6 else None
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations",
-                method=HttpMethod.POST,
-                request_model=LumaGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaGenerationRequest(
                 prompt=prompt,
                 model=model,
                 resolution=resolution,
@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
                 loop=loop,
                 concepts=luma_concepts.create_api_model() if luma_concepts else None,
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        if cls.hidden.unique_id:
-            PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=video_result_url_extractor,
-            node_id=cls.hidden.unique_id,
             estimated_duration=LUMA_T2V_AVERAGE_DURATION,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.video) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))

 class LumaImageToVideoGenerationNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on prompt, input images, and output_size.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="LumaImageToVideoNode",
             display_name="Luma Image to Video",
             category="api node/video/Luma",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on prompt, input images, and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
                     "luma_concepts",
                     tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
                     optional=True,
-                )
+                ),
             ],
             outputs=[IO.Video.Output()],
             hidden=[
@ -662,25 +521,15 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
         luma_concepts: LumaConceptChain = None,
     ) -> IO.NodeOutput:
         if first_image is None and last_image is None:
-            raise Exception(
-                "At least one of first_image and last_image requires an input."
-            )
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
+            raise Exception("At least one of first_image and last_image requires an input.")
+        keyframes = await cls._convert_to_keyframes(first_image, last_image)
         duration = duration if model != LumaVideoModel.ray_1_6 else None
         resolution = resolution if model != LumaVideoModel.ray_1_6 else None
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/luma/generations",
-                method=HttpMethod.POST,
-                request_model=LumaGenerationRequest,
-                response_model=LumaGeneration,
-            ),
-            request=LumaGenerationRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/luma/generations", method="POST"),
+            response_model=LumaGeneration,
+            data=LumaGenerationRequest(
                 prompt=prompt,
                 model=model,
                 aspect_ratio=LumaAspectRatio.ratio_16_9,  # ignored, but still needed by the API for some reason
@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
                 keyframes=keyframes,
                 concepts=luma_concepts.create_api_model() if luma_concepts else None,
             ),
-            auth_kwargs=auth_kwargs,
         )
-        response_api: LumaGeneration = await operation.execute()
-
-        if cls.hidden.unique_id:
-            PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
-
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/luma/generations/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=LumaGeneration,
-            ),
-            completed_statuses=[LumaState.completed],
-            failed_statuses=[LumaState.failed],
+        response_poll = await poll_op(
+            cls,
+            poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
+            response_model=LumaGeneration,
             status_extractor=lambda x: x.state,
-            result_url_extractor=video_result_url_extractor,
-            node_id=cls.hidden.unique_id,
             estimated_duration=LUMA_I2V_AVERAGE_DURATION,
-            auth_kwargs=auth_kwargs,
         )
-        response_poll = await operation.execute()
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.assets.video) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))

     @classmethod
     async def _convert_to_keyframes(
         cls,
         first_image: torch.Tensor = None,
         last_image: torch.Tensor = None,
-        auth_kwargs: Optional[dict[str,str]] = None,
     ):
         if first_image is None and last_image is None:
             return None
         frame0 = None
         frame1 = None
         if first_image is not None:
-            download_urls = await upload_images_to_comfyapi(
-                first_image, max_images=1, auth_kwargs=auth_kwargs,
-            )
+            download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
             frame0 = LumaImageReference(type="image", url=download_urls[0])
         if last_image is not None:
-            download_urls = await upload_images_to_comfyapi(
-                last_image, max_images=1, auth_kwargs=auth_kwargs,
-            )
+            download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
             frame1 = LumaImageReference(type="image", url=download_urls[0])
         return LumaKeyframes(frame0=frame0, frame1=frame1)
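With `download_url_to_image_tensor` and `download_url_to_video_output` in place, the per-node `aiohttp.ClientSession` blocks disappear. For reference, the removed inline download amounted to roughly the following (a sketch of the old behavior, not the new helpers' implementation; the `raise_for_status` call is an addition the original block did not make):

import aiohttp
from io import BytesIO

from comfy_api.input_impl.video_types import VideoFromFile

async def fetch_video(url: str) -> VideoFromFile:
    # Roughly what the removed inline blocks did: stream the asset into
    # memory and hand the buffer to VideoFromFile.
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            resp.raise_for_status()
            return VideoFromFile(BytesIO(await resp.read()))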

View File

@ -1,71 +1,57 @@
-from inspect import cleandoc
 from typing import Optional
-import logging
+
+import torch
 from typing_extensions import override
-from comfy_api.latest import ComfyExtension, IO
-from comfy_api.input_impl.video_types import VideoFromFile
-from comfy_api_nodes.apis import (
+
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.apis.minimax_api import (
+    MinimaxFileRetrieveResponse,
+    MiniMaxModel,
+    MinimaxTaskResultResponse,
     MinimaxVideoGenerationRequest,
     MinimaxVideoGenerationResponse,
-    MinimaxFileRetrieveResponse,
-    MinimaxTaskResultResponse,
     SubjectReferenceItem,
-    MiniMaxModel,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_bytesio,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
     upload_images_to_comfyapi,
     validate_string,
 )
-from server import PromptServer
-import torch

 I2V_AVERAGE_DURATION = 114
 T2V_AVERAGE_DURATION = 234

 async def _generate_mm_video(
+    cls: type[IO.ComfyNode],
     *,
-    auth: dict[str, str],
-    node_id: str,
     prompt_text: str,
     seed: int,
     model: str,
     image: Optional[torch.Tensor] = None,  # used for ImageToVideo
     subject: Optional[torch.Tensor] = None,  # used for SubjectToVideo
     average_duration: Optional[int] = None,
 ) -> IO.NodeOutput:
     if image is None:
         validate_string(prompt_text, field_name="prompt_text")

-    # upload image, if passed in
     image_url = None
     if image is not None:
-        image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
+        image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]

     # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
     subject_reference = None
     if subject is not None:
-        subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
+        subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
         subject_reference = [SubjectReferenceItem(image=subject_url)]

-    video_generate_operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/minimax/video_generation",
-            method=HttpMethod.POST,
-            request_model=MinimaxVideoGenerationRequest,
-            response_model=MinimaxVideoGenerationResponse,
-        ),
-        request=MinimaxVideoGenerationRequest(
+    response = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
+        response_model=MinimaxVideoGenerationResponse,
+        data=MinimaxVideoGenerationRequest(
             model=MiniMaxModel(model),
             prompt=prompt_text,
             callback_url=None,
@ -73,81 +59,50 @@ async def _generate_mm_video(
             subject_reference=subject_reference,
             prompt_optimizer=None,
         ),
-        auth_kwargs=auth,
     )
-    response = await video_generate_operation.execute()

     task_id = response.task_id
     if not task_id:
         raise Exception(f"MiniMax generation failed: {response.base_resp}")

-    video_generate_operation = PollingOperation(
-        poll_endpoint=ApiEndpoint(
-            path="/proxy/minimax/query/video_generation",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=MinimaxTaskResultResponse,
-            query_params={"task_id": task_id},
-        ),
-        completed_statuses=["Success"],
-        failed_statuses=["Fail"],
+    task_result = await poll_op(
+        cls,
+        ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
+        response_model=MinimaxTaskResultResponse,
         status_extractor=lambda x: x.status.value,
         estimated_duration=average_duration,
-        node_id=node_id,
-        auth_kwargs=auth,
     )
-    task_result = await video_generate_operation.execute()

     file_id = task_result.file_id
     if file_id is None:
         raise Exception("Request was not successful. Missing file ID.")

-    file_retrieve_operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/minimax/files/retrieve",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=MinimaxFileRetrieveResponse,
-            query_params={"file_id": int(file_id)},
-        ),
-        request=EmptyRequest(),
-        auth_kwargs=auth,
+    file_result = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
+        response_model=MinimaxFileRetrieveResponse,
     )
-    file_result = await file_retrieve_operation.execute()

     file_url = file_result.file.download_url
     if file_url is None:
-        raise Exception(
-            f"No video was found in the response. Full response: {file_result.model_dump()}"
-        )
-    logging.info("Generated video URL: %s", file_url)
-    if node_id:
-        if hasattr(file_result.file, "backup_download_url"):
-            message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
-        else:
-            message = f"Result URL: {file_url}"
-        PromptServer.instance.send_progress_text(message, node_id)
-
-    # Download and return as VideoFromFile
-    video_io = await download_url_to_bytesio(file_url)
-    if video_io is None:
-        error_msg = f"Failed to download video from {file_url}"
-        logging.error(error_msg)
-        raise Exception(error_msg)
-    return IO.NodeOutput(VideoFromFile(video_io))
+        raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
+    if file_result.file.backup_download_url:
+        try:
+            return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
+        except Exception:  # if we have a second URL to retrieve the result, try again using that one
+            return IO.NodeOutput(
+                await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
+            )
    return IO.NodeOutput(await download_url_to_video_output(file_url))

 class MinimaxTextToVideoNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="MinimaxTextToVideoNode",
             display_name="MiniMax Text to Video",
             category="api node/video/MiniMax",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on a prompt, and optional parameters.",
             inputs=[
                 IO.String.Input(
                     "prompt_text",
@ -189,11 +144,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
         seed: int = 0,
     ) -> IO.NodeOutput:
         return await _generate_mm_video(
-            auth={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
             prompt_text=prompt_text,
             seed=seed,
             model=model,
@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode):

 class MinimaxImageToVideoNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="MinimaxImageToVideoNode",
             display_name="MiniMax Image to Video",
             category="api node/video/MiniMax",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on an image and prompt, and optional parameters.",
             inputs=[
                 IO.Image.Input(
                     "image",
@ -261,11 +208,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
         seed: int = 0,
     ) -> IO.NodeOutput:
         return await _generate_mm_video(
-            auth={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
             prompt_text=prompt_text,
             seed=seed,
             model=model,
@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode):

 class MinimaxSubjectToVideoNode(IO.ComfyNode):
-    """
-    Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
-    """
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="MinimaxSubjectToVideoNode",
             display_name="MiniMax Subject to Video",
             category="api node/video/MiniMax",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos synchronously based on an image and prompt, and optional parameters.",
             inputs=[
                 IO.Image.Input(
                     "subject",
@ -333,11 +272,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
         seed: int = 0,
     ) -> IO.NodeOutput:
         return await _generate_mm_video(
-            auth={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            node_id=cls.hidden.unique_id,
+            cls,
             prompt_text=prompt_text,
             seed=seed,
             model=model,
@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):

 class MinimaxHailuoVideoNode(IO.ComfyNode):
-    """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
-
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="MinimaxHailuoVideoNode",
             display_name="MiniMax Hailuo Video",
             category="api node/video/MiniMax",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
             inputs=[
                 IO.String.Input(
                     "prompt_text",
@ -420,10 +353,6 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
         resolution: str = "768P",
         model: str = "MiniMax-Hailuo-02",
     ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         if first_frame_image is None:
             validate_string(prompt_text, field_name="prompt_text")
@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
         # upload image, if passed in
         image_url = None
         if first_frame_image is not None:
-            image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
+            image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]

-        video_generate_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/minimax/video_generation",
-                method=HttpMethod.POST,
-                request_model=MinimaxVideoGenerationRequest,
-                response_model=MinimaxVideoGenerationResponse,
-            ),
-            request=MinimaxVideoGenerationRequest(
+        response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
+            response_model=MinimaxVideoGenerationResponse,
+            data=MinimaxVideoGenerationRequest(
                 model=MiniMaxModel(model),
                 prompt=prompt_text,
                 callback_url=None,
@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
                 duration=duration,
                 resolution=resolution,
             ),
-            auth_kwargs=auth,
         )
-        response = await video_generate_operation.execute()

         task_id = response.task_id
         if not task_id:
             raise Exception(f"MiniMax generation failed: {response.base_resp}")

         average_duration = 120 if resolution == "768P" else 240
-        video_generate_operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path="/proxy/minimax/query/video_generation",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=MinimaxTaskResultResponse,
-                query_params={"task_id": task_id},
-            ),
-            completed_statuses=["Success"],
-            failed_statuses=["Fail"],
+        task_result = await poll_op(
+            cls,
+            ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
+            response_model=MinimaxTaskResultResponse,
             status_extractor=lambda x: x.status.value,
             estimated_duration=average_duration,
-            node_id=cls.hidden.unique_id,
-            auth_kwargs=auth,
         )
-        task_result = await video_generate_operation.execute()

         file_id = task_result.file_id
         if file_id is None:
             raise Exception("Request was not successful. Missing file ID.")

-        file_retrieve_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/minimax/files/retrieve",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=MinimaxFileRetrieveResponse,
-                query_params={"file_id": int(file_id)},
-            ),
-            request=EmptyRequest(),
-            auth_kwargs=auth,
+        file_result = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
+            response_model=MinimaxFileRetrieveResponse,
         )
-        file_result = await file_retrieve_operation.execute()

         file_url = file_result.file.download_url
         if file_url is None:
-            raise Exception(
-                f"No video was found in the response. Full response: {file_result.model_dump()}"
-            )
-        logging.info("Generated video URL: %s", file_url)
-        if cls.hidden.unique_id:
-            if hasattr(file_result.file, "backup_download_url"):
-                message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
-            else:
-                message = f"Result URL: {file_url}"
-            PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
-
-        video_io = await download_url_to_bytesio(file_url)
-        if video_io is None:
-            error_msg = f"Failed to download video from {file_url}"
-            logging.error(error_msg)
-            raise Exception(error_msg)
-        return IO.NodeOutput(VideoFromFile(video_io))
+            raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
+        if file_result.file.backup_download_url:
+            try:
+                return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
+            except Exception:  # if we have a second URL to retrieve the result, try again using that one
+                return IO.NodeOutput(
+                    await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
+                )
+        return IO.NodeOutput(await download_url_to_video_output(file_url))
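A behavioral change worth noting: the old code only printed the backup URL as progress text, while the new code actually falls back to it. Extracted as a sketch (assuming `download_url_to_video_output` keeps the `timeout` and `max_retries` keywords used above, and `file` is the `file_result.file` object):

async def download_with_backup(file) -> "IO.NodeOutput":
    # Mirrors the fallback in _generate_mm_video and MinimaxHailuoVideoNode:
    # a short first attempt against the primary URL (10s timeout, 2 retries),
    # then the backup URL with 3 retries; with no backup, the primary URL alone.
    if file.backup_download_url:
        try:
            return IO.NodeOutput(
                await download_url_to_video_output(file.download_url, timeout=10, max_retries=2)
            )
        except Exception:
            return IO.NodeOutput(
                await download_url_to_video_output(file.backup_download_url, max_retries=3)
            )
    return IO.NodeOutput(await download_url_to_video_output(file.download_url))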
class MinimaxExtension(ComfyExtension): class MinimaxExtension(ComfyExtension):

View File

@ -1,35 +1,31 @@
import logging import logging
from typing import Any, Callable, Optional, TypeVar from typing import Optional
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
MoonvalleyTextToVideoRequest, MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams, MoonvalleyTextToVideoInferenceParams,
MoonvalleyTextToVideoRequest,
MoonvalleyVideoToVideoInferenceParams, MoonvalleyVideoToVideoInferenceParams,
MoonvalleyVideoToVideoRequest, MoonvalleyVideoToVideoRequest,
MoonvalleyPromptResponse,
) )
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output, download_url_to_video_output,
poll_op,
sync_op,
trim_video,
upload_images_to_comfyapi, upload_images_to_comfyapi,
upload_video_to_comfyapi, upload_video_to_comfyapi,
validate_container_format_is_mp4, validate_container_format_is_mp4,
validate_image_dimensions,
validate_string,
) )
from comfy_api.input import VideoInput
from comfy_api.latest import ComfyExtension, InputImpl, IO
import av
import io
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video"
@ -51,13 +47,6 @@ MAX_VID_HEIGHT = 10000
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
R = TypeVar("R")
class MoonvalleyApiError(Exception):
"""Base exception for Moonvalley API errors."""
pass
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
@ -69,64 +58,7 @@ def validate_task_creation_response(response) -> None:
if not is_valid_task_creation_response(response): if not is_valid_task_creation_response(response):
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
logging.error(error_msg) logging.error(error_msg)
raise MoonvalleyApiError(error_msg) raise RuntimeError(error_msg)
def get_video_from_response(response):
video = response.output_url
logging.info(
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
)
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Moonvalley video generation task result.
Will not raise an error if the response is not valid.
"""
if response:
return str(get_video_from_response(response))
else:
return None
async def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
return await PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
"completed",
],
max_poll_attempts=240, # 64 minutes with 16s interval
poll_interval=16.0,
failed_statuses=["error"],
status_extractor=lambda response: (
response.status if response and response.status else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
node_id=node_id,
).execute()
def validate_prompts(
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
):
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
if not prompt:
raise ValueError("Positive prompt is empty")
if len(prompt) > max_length:
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
if negative_prompt and len(negative_prompt) > max_length:
raise ValueError(
f"Negative prompt is too long: {len(negative_prompt)} characters"
)
return True
 def validate_video_to_video_input(video: VideoInput) -> VideoInput:
@@ -170,12 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
     }
     if (width, height) not in supported_resolutions:
-        supported_list = ", ".join(
-            [f"{w}x{h}" for w, h in sorted(supported_resolutions)]
-        )
-        raise ValueError(
-            f"Resolution {width}x{height} not supported. Supported: {supported_list}"
-        )
+        supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
+        raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
 def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
@@ -188,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
 def _validate_minimum_duration(duration: float) -> None:
     """Ensures video is at least 5 seconds long."""
     if duration < 5:
-        raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
+        raise ValueError("Input video must be at least 5 seconds long.")
 def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
@@ -198,123 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
     return video
-def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
-    """
-    Returns a new VideoInput object trimmed from the beginning to the specified duration,
-    using av to avoid loading entire video into memory.
-    Args:
-        video: Input video to trim
-        duration_sec: Duration in seconds to keep from the beginning
-    Returns:
-        VideoFromFile object that owns the output buffer
-    """
-    output_buffer = io.BytesIO()
-    input_container = None
-    output_container = None
-    try:
-        # Get the stream source - this avoids loading entire video into memory
-        # when the source is already a file path
-        input_source = video.get_stream_source()
-        # Open containers
-        input_container = av.open(input_source, mode="r")
-        output_container = av.open(output_buffer, mode="w", format="mp4")
-        # Set up output streams for re-encoding
-        video_stream = None
-        audio_stream = None
-        for stream in input_container.streams:
-            logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
-            if isinstance(stream, av.VideoStream):
-                # Create output video stream with same parameters
-                video_stream = output_container.add_stream(
-                    "h264", rate=stream.average_rate
-                )
-                video_stream.width = stream.width
-                video_stream.height = stream.height
-                video_stream.pix_fmt = "yuv420p"
-                logging.info(
-                    "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate
-                )
-            elif isinstance(stream, av.AudioStream):
-                # Create output audio stream with same parameters
-                audio_stream = output_container.add_stream(
-                    "aac", rate=stream.sample_rate
-                )
-                audio_stream.sample_rate = stream.sample_rate
-                audio_stream.layout = stream.layout
-                logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
-        # Calculate target frame count that's divisible by 16
-        fps = input_container.streams.video[0].average_rate
-        estimated_frames = int(duration_sec * fps)
-        target_frames = (
-            estimated_frames // 16
-        ) * 16  # Round down to nearest multiple of 16
-        if target_frames == 0:
-            raise ValueError("Video too short: need at least 16 frames for Moonvalley")
-        frame_count = 0
-        audio_frame_count = 0
-        # Decode and re-encode video frames
-        if video_stream:
-            for frame in input_container.decode(video=0):
-                if frame_count >= target_frames:
-                    break
-                # Re-encode frame
-                for packet in video_stream.encode(frame):
-                    output_container.mux(packet)
-                frame_count += 1
-            # Flush encoder
-            for packet in video_stream.encode():
-                output_container.mux(packet)
-            logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
-        # Decode and re-encode audio frames
-        if audio_stream:
-            input_container.seek(0)  # Reset to beginning for audio
-            for frame in input_container.decode(audio=0):
-                if frame.time >= duration_sec:
-                    break
-                # Re-encode frame
-                for packet in audio_stream.encode(frame):
-                    output_container.mux(packet)
-                audio_frame_count += 1
-            # Flush encoder
-            for packet in audio_stream.encode():
-                output_container.mux(packet)
-            logging.info("Encoded %s audio frames", audio_frame_count)
-        # Close containers
-        output_container.close()
-        input_container.close()
-        # Return as VideoFromFile using the buffer
-        output_buffer.seek(0)
-        return InputImpl.VideoFromFile(output_buffer)
-    except Exception as e:
-        # Clean up on error
-        if input_container is not None:
-            input_container.close()
-        if output_container is not None:
-            output_container.close()
-        raise RuntimeError(f"Failed to trim video: {str(e)}") from e
 def parse_width_height_from_res(resolution: str):
     # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
     res_map = {
@@ -338,19 +149,14 @@ def parse_control_parameter(value):
     return control_map.get(value, control_map["Motion Transfer"])
-async def get_response(
-    task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
-) -> MoonvalleyPromptResponse:
-    return await poll_until_finished(
-        auth_kwargs,
-        ApiEndpoint(
-            path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=MoonvalleyPromptResponse,
-        ),
-        result_url_extractor=get_video_url_from_response,
-        node_id=node_id,
-    )
+async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
+    return await poll_op(
+        cls,
+        ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
+        response_model=MoonvalleyPromptResponse,
+        status_extractor=lambda r: (r.status if r and r.status else None),
+        poll_interval=16.0,
+        max_poll_attempts=240,
+    )
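Note: poll_op takes the node class where the old helper took auth_kwargs and node_id, and it reports status through the same extractor hook. A sketch of the generic shape, inferred from the call sites in this file (the placeholder path and the GET-by-default behavior are assumptions):

    final = await poll_op(
        cls,                                    # IO.ComfyNode subclass; supplies auth and the node id
        ApiEndpoint(path="/proxy/<vendor>/prompts/<task_id>"),  # hypothetical path; GET assumed default
        response_model=MoonvalleyPromptResponse,
        status_extractor=lambda r: (r.status if r and r.status else None),
        poll_interval=16.0,                     # 240 attempts at 16 s is roughly 64 minutes
        max_poll_attempts=240,
    )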
@@ -444,14 +250,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
         steps: int,
     ) -> IO.NodeOutput:
         validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
-        validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
+        validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
+        validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
         width_height = parse_width_height_from_res(resolution)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         inference_params = MoonvalleyTextToVideoInferenceParams(
             negative_prompt=negative_prompt,
             steps=steps,
@@ -464,33 +266,17 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
         # Get MIME type from tensor - assuming PNG format for image tensors
         mime_type = "image/png"
-        image_url = (
-            await upload_images_to_comfyapi(
-                image, max_images=1, auth_kwargs=auth, mime_type=mime_type
-            )
-        )[0]
-        request = MoonvalleyTextToVideoRequest(
-            image_url=image_url, prompt_text=prompt, inference_params=inference_params
-        )
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=API_IMG2VIDEO_ENDPOINT,
-                method=HttpMethod.POST,
-                request_model=MoonvalleyTextToVideoRequest,
-                response_model=MoonvalleyPromptResponse,
-            ),
-            request=request,
-            auth_kwargs=auth,
-        )
-        task_creation_response = await initial_operation.execute()
+        image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
+        task_creation_response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
+            response_model=MoonvalleyPromptResponse,
+            data=MoonvalleyTextToVideoRequest(
+                image_url=image_url, prompt_text=prompt, inference_params=inference_params
+            ),
+        )
         validate_task_creation_response(task_creation_response)
-        task_id = task_creation_response.id
-        final_response = await get_response(
-            task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
-        )
+        final_response = await get_response(cls, task_creation_response.id)
         video = await download_url_to_video_output(final_response.output_url)
         return IO.NodeOutput(video)
@@ -582,15 +368,10 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
         steps=33,
         prompt_adherence=4.5,
     ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         validated_video = validate_video_to_video_input(video)
-        video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
-        validate_prompts(prompt, negative_prompt)
+        video_url = await upload_video_to_comfyapi(cls, validated_video)
+        validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
+        validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
         # Only include motion_intensity for Motion Transfer
         control_params = {}
@@ -605,35 +386,20 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
             guidance_scale=prompt_adherence,
         )
-        control = parse_control_parameter(control_type)
-        request = MoonvalleyVideoToVideoRequest(
-            control_type=control,
-            video_url=video_url,
-            prompt_text=prompt,
-            inference_params=inference_params,
-        )
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=API_VIDEO2VIDEO_ENDPOINT,
-                method=HttpMethod.POST,
-                request_model=MoonvalleyVideoToVideoRequest,
-                response_model=MoonvalleyPromptResponse,
-            ),
-            request=request,
-            auth_kwargs=auth,
-        )
-        task_creation_response = await initial_operation.execute()
+        task_creation_response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
+            response_model=MoonvalleyPromptResponse,
+            data=MoonvalleyVideoToVideoRequest(
+                control_type=parse_control_parameter(control_type),
+                video_url=video_url,
+                prompt_text=prompt,
+                inference_params=inference_params,
+            ),
+        )
         validate_task_creation_response(task_creation_response)
-        task_id = task_creation_response.id
-        final_response = await get_response(
-            task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
-        )
-        video = await download_url_to_video_output(final_response.output_url)
-        return IO.NodeOutput(video)
+        final_response = await get_response(cls, task_creation_response.id)
+        return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
 class MoonvalleyTxt2VideoNode(IO.ComfyNode):
@@ -720,14 +486,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
         seed: int,
         steps: int,
     ) -> IO.NodeOutput:
-        validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
+        validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
+        validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
         width_height = parse_width_height_from_res(resolution)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         inference_params = MoonvalleyTextToVideoInferenceParams(
             negative_prompt=negative_prompt,
             steps=steps,
@@ -737,30 +499,16 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
             width=width_height["width"],
             height=width_height["height"],
         )
-        request = MoonvalleyTextToVideoRequest(
-            prompt_text=prompt, inference_params=inference_params
-        )
-        init_op = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=API_TXT2VIDEO_ENDPOINT,
-                method=HttpMethod.POST,
-                request_model=MoonvalleyTextToVideoRequest,
-                response_model=MoonvalleyPromptResponse,
-            ),
-            request=request,
-            auth_kwargs=auth,
-        )
-        task_creation_response = await init_op.execute()
+        task_creation_response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
+            response_model=MoonvalleyPromptResponse,
+            data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
+        )
        validate_task_creation_response(task_creation_response)
-        task_id = task_creation_response.id
-        final_response = await get_response(
-            task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
-        )
-        video = await download_url_to_video_output(final_response.output_url)
-        return IO.NodeOutput(video)
+        final_response = await get_response(cls, task_creation_response.id)
+        return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
 class MoonvalleyExtension(ComfyExtension):
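Note: all three Moonvalley nodes now share one shape: sync_op to create the task, get_response (poll_op) to wait for it, then a download helper. Condensed from the Txt2Video body above, not a new API surface:

    task = await sync_op(
        cls,
        endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
        response_model=MoonvalleyPromptResponse,
        data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
    )
    validate_task_creation_response(task)
    final = await get_response(cls, task.id)
    return IO.NodeOutput(await download_url_to_video_output(final.output_url))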

File diff suppressed because it is too large

View File

@@ -7,28 +7,23 @@ from __future__ import annotations
 from io import BytesIO
 import logging
-from typing import Optional, TypeVar
+from typing import Optional
 import torch
 from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO
 from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
-from comfy_api_nodes.apinode_utils import (
-    download_url_to_video_output,
-    tensor_to_bytesio,
-    validate_string,
-)
-from comfy_api_nodes.apis import pika_defs
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    EmptyRequest,
-    HttpMethod,
-    PollingOperation,
-    SynchronousOperation,
-)
-R = TypeVar("R")
+from comfy_api_nodes.apis import pika_api as pika_defs
+from comfy_api_nodes.util import (
+    validate_string,
+    download_url_to_video_output,
+    tensor_to_bytesio,
+    ApiEndpoint,
+    sync_op,
+    poll_op,
+)
 PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
 PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@ -44,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task( async def execute_task(
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse], task_id: str,
auth_kwargs: Optional[dict[str, str]] = None, cls: type[IO.ComfyNode],
node_id: Optional[str] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
task_id = (await initial_operation.execute()).video_id final_response: pika_defs.PikaVideoResponse = await poll_op(
final_response: pika_defs.PikaVideoResponse = await PollingOperation( cls,
poll_endpoint=ApiEndpoint( ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
path=f"{PATH_VIDEO_GET}/{task_id}", response_model=pika_defs.PikaVideoResponse,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: (response.status.value if response.status else None), status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None), progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60, estimated_duration=60,
max_poll_attempts=240, max_poll_attempts=240,
).execute() )
if not final_response.url: if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}" error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg) logging.error(error_msg)
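Note: execute_task no longer performs the initial request; each node submits with sync_op and passes the returned video_id in for polling. The resulting call pattern, taken from the node bodies below:

    response = await sync_op(
        cls,
        ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
        response_model=pika_defs.PikaGenerateResponse,
        data=pika_request_data,
        files=pika_files,
        content_type="multipart/form-data",
    )
    return await execute_task(response.video_id, cls)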
@@ -128,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
             resolution=resolution,
             duration=duration,
         )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_IMAGE_TO_VIDEO,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
             files=pika_files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikaTextToVideoNode(IO.ComfyNode):
@@ -187,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
         duration: int,
         aspect_ratio: float,
     ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_TEXT_TO_VIDEO,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
                 promptText=prompt_text,
                 negativePrompt=negative_prompt,
                 seed=seed,
@@ -206,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
                 duration=duration,
                 aspectRatio=aspect_ratio,
             ),
-            auth_kwargs=auth,
             content_type="application/x-www-form-urlencoded",
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikaScenes(IO.ComfyNode):
@@ -313,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
             duration=duration,
             aspectRatio=aspect_ratio,
         )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKASCENES,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
             files=pika_files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikAdditionsNode(IO.ComfyNode):
@@ -387,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
             negativePrompt=negative_prompt,
             seed=seed,
         )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKADDITIONS,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
             files=pika_files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikaSwapsNode(IO.ComfyNode):
@@ -476,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
             seed=seed,
             modifyRegionRoi=region_to_modify if region_to_modify else None,
         )
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKASWAPS,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_request_data,
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_request_data,
             files=pika_files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikaffectsNode(IO.ComfyNode):
@@ -532,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
         negative_prompt: str,
         seed: int,
     ) -> IO.NodeOutput:
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKAFFECTS,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
                 pikaffect=pikaffect,
                 promptText=prompt_text,
                 negativePrompt=negative_prompt,
@@ -551,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
             ),
             files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikaStartEndFrameNode(IO.ComfyNode):
@@ -596,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
             ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
             ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
         ]
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_PIKAFRAMES,
-                method=HttpMethod.POST,
-                request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
-                response_model=pika_defs.PikaGenerateResponse,
-            ),
-            request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
+        initial_operation = await sync_op(
+            cls,
+            ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
+            response_model=pika_defs.PikaGenerateResponse,
+            data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
                 promptText=prompt_text,
                 negativePrompt=negative_prompt,
                 seed=seed,
@@ -616,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
             ),
             files=pika_files,
             content_type="multipart/form-data",
-            auth_kwargs=auth,
         )
-        return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
+        return await execute_task(initial_operation.video_id, cls)
 class PikaApiNodesExtension(ComfyExtension):
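Note: the per-node auth dict is gone because sync_op and poll_op read credentials from the node class itself. Nodes still declare the hidden inputs in define_schema(); the unique_id entry is presumably what lets the helpers attach progress to the right node (an assumption based on the Rodin hunks below):

    hidden=[
        IO.Hidden.auth_token_comfy_org,
        IO.Hidden.api_key_comfy_org,
        IO.Hidden.unique_id,  # assumed: used by the helpers for per-node progress reporting
    ],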

View File

@@ -1,7 +1,6 @@
-from inspect import cleandoc
-from typing import Optional
+import torch
 from typing_extensions import override
-from io import BytesIO
+from comfy_api.latest import IO, ComfyExtension
 from comfy_api_nodes.apis.pixverse_api import (
     PixverseTextVideoRequest,
     PixverseImageVideoRequest,
@@ -17,59 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import (
     PixverseIO,
     pixverse_templates,
 )
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
-    tensor_to_bytesio,
-    validate_string,
-)
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO
-import torch
-import aiohttp
+from comfy_api_nodes.util import (
+    ApiEndpoint,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
+    tensor_to_bytesio,
+    validate_string,
+)
 AVERAGE_DURATION_T2V = 32
 AVERAGE_DURATION_I2V = 30
 AVERAGE_DURATION_T2T = 52
-def get_video_url_from_response(
-    response: PixverseGenerationStatusResponse,
-) -> Optional[str]:
-    if response.Resp is None or response.Resp.url is None:
-        return None
-    return str(response.Resp.url)
-async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
-    # first, upload image to Pixverse and get image id to use in actual generation call
-    files = {"image": tensor_to_bytesio(image)}
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/pixverse/image/upload",
-            method=HttpMethod.POST,
-            request_model=EmptyRequest,
-            response_model=PixverseImageUploadResponse,
-        ),
-        request=EmptyRequest(),
-        files=files,
-        content_type="multipart/form-data",
-        auth_kwargs=auth_kwargs,
-    )
-    response_upload: PixverseImageUploadResponse = await operation.execute()
-    if response_upload.Resp is None:
-        raise Exception(
-            f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
-        )
-    return response_upload.Resp.img_id
+async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
+    response_upload = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
+        response_model=PixverseImageUploadResponse,
+        files={"image": tensor_to_bytesio(image)},
+        content_type="multipart/form-data",
+    )
+    if response_upload.Resp is None:
+        raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
+    return response_upload.Resp.img_id
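Note: file uploads ride the same sync_op path as JSON bodies; files plus a multipart content type replace the old EmptyRequest/SynchronousOperation pairing. A sketch under the signature shown above:

    upload = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
        response_model=PixverseImageUploadResponse,
        files={"image": tensor_to_bytesio(image)},  # no request-body model needed
        content_type="multipart/form-data",
    )
    img_id = upload.Resp.img_id if upload.Resp else None  # callers raise on a missing Resp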
@@ -95,22 +65,17 @@ class PixverseTemplateNode(IO.ComfyNode):
         template_id = pixverse_templates.get(template, None)
         if template_id is None:
             raise Exception(f"Template '{template}' is not recognized.")
-        # just return the integer
         return IO.NodeOutput(template_id)
 class PixverseTextToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTextToVideoNode",
             display_name="PixVerse Text to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.String.Input(
                     "prompt",
@@ -177,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
-        validate_string(prompt, strip_whitespace=False)
+        validate_string(prompt, strip_whitespace=False, min_length=1)
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
         if quality == PixverseQuality.res_1080p:
@@ -186,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/text/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTextVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTextVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseTextVideoRequest(
                 prompt=prompt,
                 aspect_ratio=aspect_ratio,
                 quality=quality,
@@ -207,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -228,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
 class PixverseImageToVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseImageToVideoNode",
             display_name="PixVerse Image to Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("image"),
                 IO.String.Input(
@@ -316,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
         pixverse_template: int = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
+        img_id = await upload_image_to_pixverse(cls, image)
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -330,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/img/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseImageVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseImageVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseImageVideoRequest(
                 img_id=img_id,
                 prompt=prompt,
                 quality=quality,
@@ -347,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode):
                 template_id=pixverse_template,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -368,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_I2V,
         )
-        response_poll = await operation.execute()
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
 class PixverseTransitionVideoNode(IO.ComfyNode):
-    """
-    Generates videos based on prompt and output_size.
-    """
     @classmethod
     def define_schema(cls) -> IO.Schema:
         return IO.Schema(
             node_id="PixverseTransitionVideoNode",
             display_name="PixVerse Transition Video",
             category="api node/video/PixVerse",
-            description=cleandoc(cls.__doc__ or ""),
+            description="Generates videos based on prompt and output_size.",
             inputs=[
                 IO.Image.Input("first_frame"),
                 IO.Image.Input("last_frame"),
@@ -452,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
         negative_prompt: str = None,
     ) -> IO.NodeOutput:
         validate_string(prompt, strip_whitespace=False)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
-        last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
+        first_frame_id = await upload_image_to_pixverse(cls, first_frame)
+        last_frame_id = await upload_image_to_pixverse(cls, last_frame)
         # 1080p is limited to 5 seconds duration
         # only normal motion_mode supported for 1080p or for non-5 second duration
@@ -467,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
         elif duration_seconds != PixverseDuration.dur_5:
             motion_mode = PixverseMotionMode.normal
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/pixverse/video/transition/generate",
-                method=HttpMethod.POST,
-                request_model=PixverseTransitionVideoRequest,
-                response_model=PixverseVideoResponse,
-            ),
-            request=PixverseTransitionVideoRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
+            response_model=PixverseVideoResponse,
+            data=PixverseTransitionVideoRequest(
                 first_frame_img=first_frame_id,
                 last_frame_img=last_frame_id,
                 prompt=prompt,
@@ -484,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
                 negative_prompt=negative_prompt if negative_prompt else None,
                 seed=seed,
             ),
-            auth_kwargs=auth,
         )
-        response_api = await operation.execute()
         if response_api.Resp is None:
             raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=PixverseGenerationStatusResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
+            response_model=PixverseGenerationStatusResponse,
             completed_statuses=[PixverseStatus.successful],
             failed_statuses=[
                 PixverseStatus.contents_moderation,
@@ -505,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
                 PixverseStatus.deleted,
             ],
             status_extractor=lambda x: x.Resp.status,
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
-            result_url_extractor=get_video_url_from_response,
             estimated_duration=AVERAGE_DURATION_T2V,
         )
-        response_poll = await operation.execute()
-        async with aiohttp.ClientSession() as session:
-            async with session.get(response_poll.Resp.url) as vid_response:
-                return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
+        return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
 class PixVerseExtension(ComfyExtension):
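Note: unlike the Moonvalley polls, the PixVerse polls keep explicit terminal-status lists because the status enum carries several failure states. The shared shape (the middle failure entries are elided by the hunk boundaries above):

    response_poll = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/pixverse/video/result/{video_id}"),  # video_id from the create call
        response_model=PixverseGenerationStatusResponse,
        completed_statuses=[PixverseStatus.successful],
        failed_statuses=[
            PixverseStatus.contents_moderation,
            # ...remaining failure states as in the lists above...
            PixverseStatus.deleted,
        ],
        status_extractor=lambda x: x.Resp.status,
        estimated_duration=AVERAGE_DURATION_T2V,
    )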

File diff suppressed because it is too large

View File

@@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/
 """
-from __future__ import annotations
 from inspect import cleandoc
 import folder_paths as comfy_paths
-import aiohttp
 import os
-import asyncio
 import logging
 import math
 from typing import Optional
@@ -26,11 +23,11 @@ from comfy_api_nodes.apis.rodin_api import (
     Rodin3DDownloadResponse,
     JobStatus,
 )
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
+    sync_op,
+    poll_op,
     ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
+    download_url_to_bytesio,
 )
 from comfy_api.latest import ComfyExtension, IO
@@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
 async def create_generate_task(
+    cls: type[IO.ComfyNode],
     images=None,
     seed=1,
     material="PBR",
     quality_override=18000,
     tier="Regular",
     mesh_mode="Quad",
-    TAPose = False,
-    auth_kwargs: Optional[dict[str, str]] = None,
+    ta_pose: bool = False,
 ):
     if images is None:
         raise Exception("Rodin 3D generate requires at least 1 image.")
     if len(images) > 5:
         raise Exception("Rodin 3D generate requires up to 5 image.")
-    path = "/proxy/rodin/api/v2/rodin"
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path=path,
-            method=HttpMethod.POST,
-            request_model=Rodin3DGenerateRequest,
-            response_model=Rodin3DGenerateResponse,
-        ),
-        request=Rodin3DGenerateRequest(
+    response = await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
+        response_model=Rodin3DGenerateResponse,
+        data=Rodin3DGenerateRequest(
            seed=seed,
            tier=tier,
            material=material,
            quality_override=quality_override,
            mesh_mode=mesh_mode,
-            TAPose=TAPose,
+            TAPose=ta_pose,
         ),
         files=[
             (
@@ -159,11 +152,8 @@ async def create_generate_task(
             for image in images if image is not None
         ],
         content_type="multipart/form-data",
-        auth_kwargs=auth_kwargs,
     )
-    response = await operation.execute()
     if hasattr(response, "error"):
         error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
         logging.error(error_message)
@@ -187,75 +177,46 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
         return "DONE"
     return "Generating"
+def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
+    if not response.jobs:
+        return None
+    completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
+    return int((completed_count / len(response.jobs)) * 100)
-async def poll_for_task_status(
-    subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
-) -> Rodin3DCheckStatusResponse:
-    poll_operation = PollingOperation(
-        poll_endpoint=ApiEndpoint(
-            path="/proxy/rodin/api/v2/status",
-            method=HttpMethod.POST,
-            request_model=Rodin3DCheckStatusRequest,
-            response_model=Rodin3DCheckStatusResponse,
-        ),
-        request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
-        completed_statuses=["DONE"],
-        failed_statuses=["FAILED"],
-        status_extractor=check_rodin_status,
-        poll_interval=3.0,
-        auth_kwargs=auth_kwargs,
-    )
+async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
     logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
-    return await poll_operation.execute()
+    return await poll_op(
+        cls,
+        ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
+        response_model=Rodin3DCheckStatusResponse,
+        data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
+        status_extractor=check_rodin_status,
+        progress_extractor=extract_progress,
+    )
-async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
-    logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
-    operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path="/proxy/rodin/api/v2/download",
-            method=HttpMethod.POST,
-            request_model=Rodin3DDownloadRequest,
-            response_model=Rodin3DDownloadResponse,
-        ),
-        request=Rodin3DDownloadRequest(task_uuid=uuid),
-        auth_kwargs=auth_kwargs,
-    )
-    return await operation.execute()
+async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
+    logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
+    return await sync_op(
+        cls,
+        ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
+        response_model=Rodin3DDownloadResponse,
+        data=Rodin3DDownloadRequest(task_uuid=uuid),
+        monitor_progress=False,
+    )
-async def download_files(url_list, task_uuid):
-    save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
+async def download_files(url_list, task_uuid: str):
+    result_folder_name = f"Rodin3D_{task_uuid}"
+    save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
     os.makedirs(save_path, exist_ok=True)
     model_file_path = None
-    async with aiohttp.ClientSession() as session:
-        for i in url_list.list:
-            url = i.url
-            file_name = i.name
-            file_path = os.path.join(save_path, file_name)
-            if file_path.endswith(".glb"):
-                model_file_path = file_path
-            logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
-            max_retries = 5
-            for attempt in range(max_retries):
-                try:
-                    async with session.get(url) as resp:
-                        resp.raise_for_status()
-                        with open(file_path, "wb") as f:
-                            async for chunk in resp.content.iter_chunked(32 * 1024):
-                                f.write(chunk)
-                    break
-                except Exception as e:
-                    logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e))
-                    if attempt < max_retries - 1:
-                        logging.info("Retrying...")
-                        await asyncio.sleep(2)
-                    else:
-                        logging.info(
-                            "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
-                            file_path,
-                            max_retries,
-                        )
+    for i in url_list.list:
+        file_path = os.path.join(save_path, i.name)
+        if file_path.endswith(".glb"):
+            model_file_path = os.path.join(result_folder_name, i.name)
+        await download_url_to_bytesio(i.url, file_path)
     return model_file_path
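Note: extract_progress is the hook poll_op uses to surface percent-complete; Rodin reports one status per sub-job, so progress is the share of jobs already Done. A worked example under that definition:

    # With four sub-jobs of which three are Done:
    #   completed_count = 3, len(response.jobs) = 4 -> int((3 / 4) * 100) == 75
    progress = extract_progress(status_response)  # 75, or None when no jobs are listed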
@@ -277,6 +238,7 @@ class Rodin3D_Regular(IO.ComfyNode):
             hidden=[
                 IO.Hidden.auth_token_comfy_org,
                 IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
             ],
             is_api_node=True,
         )
@@ -295,21 +257,17 @@ class Rodin3D_Regular(IO.ComfyNode):
         for i in range(num_images):
             m_images.append(Images[i])
         mesh_mode, quality_override = get_quality_mode(Polygon_count)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         task_uuid, subscription_key = await create_generate_task(
+            cls,
             images=m_images,
             seed=Seed,
             material=Material_Type,
             quality_override=quality_override,
             tier=tier,
             mesh_mode=mesh_mode,
-            auth_kwargs=auth,
         )
-        await poll_for_task_status(subscription_key, auth_kwargs=auth)
-        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
+        await poll_for_task_status(subscription_key, cls)
+        download_list = await get_rodin_download_list(task_uuid, cls)
         model = await download_files(download_list, task_uuid)
         return IO.NodeOutput(model)
@@ -333,6 +291,7 @@ class Rodin3D_Detail(IO.ComfyNode):
             hidden=[
                 IO.Hidden.auth_token_comfy_org,
                 IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
             ],
             is_api_node=True,
         )
@@ -351,21 +310,17 @@ class Rodin3D_Detail(IO.ComfyNode):
         for i in range(num_images):
             m_images.append(Images[i])
         mesh_mode, quality_override = get_quality_mode(Polygon_count)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         task_uuid, subscription_key = await create_generate_task(
+            cls,
             images=m_images,
             seed=Seed,
             material=Material_Type,
             quality_override=quality_override,
             tier=tier,
             mesh_mode=mesh_mode,
-            auth_kwargs=auth,
         )
-        await poll_for_task_status(subscription_key, auth_kwargs=auth)
-        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
+        await poll_for_task_status(subscription_key, cls)
+        download_list = await get_rodin_download_list(task_uuid, cls)
         model = await download_files(download_list, task_uuid)
         return IO.NodeOutput(model)
@@ -389,6 +344,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
             hidden=[
                 IO.Hidden.auth_token_comfy_org,
                 IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
             ],
             is_api_node=True,
         )
@@ -401,27 +357,22 @@ class Rodin3D_Smooth(IO.ComfyNode):
         Material_Type,
         Polygon_count,
     ) -> IO.NodeOutput:
-        tier = "Smooth"
         num_images = Images.shape[0]
         m_images = []
         for i in range(num_images):
             m_images.append(Images[i])
         mesh_mode, quality_override = get_quality_mode(Polygon_count)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         task_uuid, subscription_key = await create_generate_task(
+            cls,
             images=m_images,
             seed=Seed,
             material=Material_Type,
             quality_override=quality_override,
-            tier=tier,
+            tier="Smooth",
             mesh_mode=mesh_mode,
-            auth_kwargs=auth,
         )
-        await poll_for_task_status(subscription_key, auth_kwargs=auth)
-        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
+        await poll_for_task_status(subscription_key, cls)
+        download_list = await get_rodin_download_list(task_uuid, cls)
         model = await download_files(download_list, task_uuid)
         return IO.NodeOutput(model)
@@ -452,6 +403,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
             hidden=[
                 IO.Hidden.auth_token_comfy_org,
                 IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
             ],
             is_api_node=True,
         )
@@ -462,29 +414,21 @@ class Rodin3D_Sketch(IO.ComfyNode):
         Images,
         Seed,
     ) -> IO.NodeOutput:
-        tier = "Sketch"
         num_images = Images.shape[0]
         m_images = []
         for i in range(num_images):
             m_images.append(Images[i])
-        material_type = "PBR"
-        quality_override = 18000
-        mesh_mode = "Quad"
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         task_uuid, subscription_key = await create_generate_task(
+            cls,
             images=m_images,
             seed=Seed,
-            material=material_type,
-            quality_override=quality_override,
-            tier=tier,
-            mesh_mode=mesh_mode,
-            auth_kwargs=auth,
+            material="PBR",
+            quality_override=18000,
+            tier="Sketch",
+            mesh_mode="Quad",
         )
-        await poll_for_task_status(subscription_key, auth_kwargs=auth)
-        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
+        await poll_for_task_status(subscription_key, cls)
+        download_list = await get_rodin_download_list(task_uuid, cls)
         model = await download_files(download_list, task_uuid)
         return IO.NodeOutput(model)
@@ -523,6 +467,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
             hidden=[
                 IO.Hidden.auth_token_comfy_org,
                 IO.Hidden.api_key_comfy_org,
+                IO.Hidden.unique_id,
             ],
             is_api_node=True,
         )
@@ -542,22 +487,18 @@ class Rodin3D_Gen2(IO.ComfyNode):
         for i in range(num_images):
             m_images.append(Images[i])
         mesh_mode, quality_override = get_quality_mode(Polygon_count)
-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
         task_uuid, subscription_key = await create_generate_task(
+            cls,
             images=m_images,
             seed=Seed,
             material=Material_Type,
             quality_override=quality_override,
             tier=tier,
             mesh_mode=mesh_mode,
-            TAPose=TAPose,
-            auth_kwargs=auth,
+            ta_pose=TAPose,
         )
-        await poll_for_task_status(subscription_key, auth_kwargs=auth)
-        download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
+        await poll_for_task_status(subscription_key, cls)
+        download_list = await get_rodin_download_list(task_uuid, cls)
         model = await download_files(download_list, task_uuid)
         return IO.NodeOutput(model)

View File

@@ -11,7 +11,7 @@ User Guides:
 """
-from typing import Union, Optional, Any
+from typing import Union, Optional
 from typing_extensions import override
 from enum import Enum
@@ -21,7 +21,6 @@ from comfy_api_nodes.apis import (
     RunwayImageToVideoRequest,
     RunwayImageToVideoResponse,
     RunwayTaskStatusResponse as TaskStatusResponse,
-    RunwayTaskStatusEnum as TaskStatus,
     RunwayModelEnum as Model,
     RunwayDurationEnum as Duration,
     RunwayAspectRatioEnum as AspectRatio,
@@ -33,23 +32,20 @@ from comfy_api_nodes.apis (
     ReferenceImage,
     RunwayTextToImageAspectRatioEnum,
 )
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
-    upload_images_to_comfyapi,
-    download_url_to_video_output,
+from comfy_api_nodes.util import (
     image_tensor_pair_to_batch,
     validate_string,
+    validate_image_dimensions,
+    validate_image_aspect_ratio,
+    upload_images_to_comfyapi,
+    download_url_to_video_output,
     download_url_to_image_tensor,
+    ApiEndpoint,
+    sync_op,
+    poll_op,
 )
 from comfy_api.input_impl import VideoFromFile
 from comfy_api.latest import ComfyExtension, IO
-from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
 PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
 PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
     return None
-async def poll_until_finished(
-    auth_kwargs: dict[str, str],
-    api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
-    estimated_duration: Optional[int] = None,
-    node_id: Optional[str] = None,
-) -> TaskStatusResponse:
-    """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
-    return await PollingOperation(
-        poll_endpoint=api_endpoint,
-        completed_statuses=[
-            TaskStatus.SUCCEEDED.value,
-        ],
-        failed_statuses=[
-            TaskStatus.FAILED.value,
-            TaskStatus.CANCELLED.value,
-        ],
-        status_extractor=lambda response: response.status.value,
-        auth_kwargs=auth_kwargs,
-        result_url_extractor=get_video_url_from_task_status,
-        estimated_duration=estimated_duration,
-        node_id=node_id,
-        progress_extractor=extract_progress_from_task_status,
-    ).execute()
 def extract_progress_from_task_status(
     response: TaskStatusResponse,
 ) -> Union[float, None]:
@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
async def get_response( async def get_response(
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
) -> TaskStatusResponse: ) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response.""" """Poll the task status until it is finished then get the response."""
return await poll_until_finished( return await poll_op(
auth_kwargs, cls,
ApiEndpoint( ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
path=f"{PATH_GET_TASK_STATUS}/{task_id}", response_model=TaskStatusResponse,
method=HttpMethod.GET, status_extractor=lambda r: r.status.value,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=estimated_duration, estimated_duration=estimated_duration,
node_id=node_id, progress_extractor=extract_progress_from_task_status,
) )
async def generate_video(
+    cls: type[IO.ComfyNode],
    request: RunwayImageToVideoRequest,
-    auth_kwargs: dict[str, str],
-    node_id: Optional[str] = None,
    estimated_duration: Optional[int] = None,
) -> VideoFromFile:
-    initial_operation = SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path=PATH_IMAGE_TO_VIDEO,
-            method=HttpMethod.POST,
-            request_model=RunwayImageToVideoRequest,
-            response_model=RunwayImageToVideoResponse,
-        ),
-        request=request,
-        auth_kwargs=auth_kwargs,
+    initial_response = await sync_op(
+        cls,
+        endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
+        response_model=RunwayImageToVideoResponse,
+        data=request,
    )
-    initial_response = await initial_operation.execute()
-    final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
+    final_response = await get_response(cls, initial_response.id, estimated_duration)

    if not final_response.output:
        raise RunwayApiError("Runway task succeeded but no video data found in response.")
@@ -184,9 +145,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
            display_name="Runway Image to Video (Gen3a Turbo)",
            category="api node/video/Runway",
            description="Generate a video from a single starting frame using Gen3a Turbo model. "
            "Before diving in, review these best practices to ensure that "
            "your input selections will set your generation up for success: "
            "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
            inputs=[
                IO.String.Input(
                    "prompt",
@@ -239,22 +200,18 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
    ) -> IO.NodeOutput:
        validate_string(prompt, min_length=1)
        validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
-
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))

        download_urls = await upload_images_to_comfyapi(
+            cls,
            start_frame,
            max_images=1,
            mime_type="image/png",
-            auth_kwargs=auth_kwargs,
        )

        return IO.NodeOutput(
            await generate_video(
+                cls,
                RunwayImageToVideoRequest(
                    promptText=prompt,
                    seed=seed,
@@ -262,15 +219,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
                    duration=Duration(duration),
                    ratio=AspectRatio(ratio),
                    promptImage=RunwayPromptImageObject(
-                        root=[
-                            RunwayPromptImageDetailedObject(
-                                uri=str(download_urls[0]), position="first"
-                            )
-                        ]
+                        root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
                    ),
                ),
-                auth_kwargs=auth_kwargs,
-                node_id=cls.hidden.unique_id,
            )
        )
@@ -284,9 +235,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
            display_name="Runway Image to Video (Gen4 Turbo)",
            category="api node/video/Runway",
            description="Generate a video from a single starting frame using Gen4 Turbo model. "
            "Before diving in, review these best practices to ensure that "
            "your input selections will set your generation up for success: "
            "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
            inputs=[
                IO.String.Input(
                    "prompt",
@@ -339,22 +290,18 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
    ) -> IO.NodeOutput:
        validate_string(prompt, min_length=1)
        validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
-
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))

        download_urls = await upload_images_to_comfyapi(
+            cls,
            start_frame,
            max_images=1,
            mime_type="image/png",
-            auth_kwargs=auth_kwargs,
        )

        return IO.NodeOutput(
            await generate_video(
+                cls,
                RunwayImageToVideoRequest(
                    promptText=prompt,
                    seed=seed,
@@ -362,15 +309,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
                    duration=Duration(duration),
                    ratio=AspectRatio(ratio),
                    promptImage=RunwayPromptImageObject(
-                        root=[
-                            RunwayPromptImageDetailedObject(
-                                uri=str(download_urls[0]), position="first"
-                            )
-                        ]
+                        root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
                    ),
                ),
-                auth_kwargs=auth_kwargs,
-                node_id=cls.hidden.unique_id,
                estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
            )
        )
@@ -385,12 +326,12 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
            display_name="Runway First-Last-Frame to Video",
            category="api node/video/Runway",
            description="Upload first and last keyframes, draft a prompt, and generate a video. "
            "More complex transitions, such as cases where the Last frame is completely different "
            "from the First frame, may benefit from the longer 10s duration. "
            "This would give the generation more time to smoothly transition between the two inputs. "
            "Before diving in, review these best practices to ensure that your input selections "
            "will set your generation up for success: "
            "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
            inputs=[
                IO.String.Input(
                    "prompt",
@@ -449,26 +390,22 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
        validate_string(prompt, min_length=1)
        validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
        validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
-        validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
-        validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
-
-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
+        validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
+        validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))

        stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
        download_urls = await upload_images_to_comfyapi(
+            cls,
            stacked_input_images,
            max_images=2,
            mime_type="image/png",
-            auth_kwargs=auth_kwargs,
        )
        if len(download_urls) != 2:
            raise RunwayApiError("Failed to upload one or more images to comfy api.")

        return IO.NodeOutput(
            await generate_video(
+                cls,
                RunwayImageToVideoRequest(
                    promptText=prompt,
                    seed=seed,
@@ -477,17 +414,11 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
                    ratio=AspectRatio(ratio),
                    promptImage=RunwayPromptImageObject(
                        root=[
-                            RunwayPromptImageDetailedObject(
-                                uri=str(download_urls[0]), position="first"
-                            ),
-                            RunwayPromptImageDetailedObject(
-                                uri=str(download_urls[1]), position="last"
-                            ),
+                            RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
+                            RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
                        ]
                    ),
                ),
-                auth_kwargs=auth_kwargs,
-                node_id=cls.hidden.unique_id,
                estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
            )
        )
@@ -502,7 +433,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
            display_name="Runway Text to Image",
            category="api node/image/Runway",
            description="Generate an image from a text prompt using Runway's Gen 4 model. "
            "You can also include reference image to guide the generation.",
            inputs=[
                IO.String.Input(
                    "prompt",
@@ -540,49 +471,34 @@ class RunwayTextToImageNode(IO.ComfyNode):
    ) -> IO.NodeOutput:
        validate_string(prompt, min_length=1)

-        auth_kwargs = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
        # Prepare reference images if provided
        reference_images = None
        if reference_image is not None:
            validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
-            validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
+            validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))

            download_urls = await upload_images_to_comfyapi(
+                cls,
                reference_image,
                max_images=1,
                mime_type="image/png",
-                auth_kwargs=auth_kwargs,
            )
            reference_images = [ReferenceImage(uri=str(download_urls[0]))]

-        request = RunwayTextToImageRequest(
-            promptText=prompt,
-            model=Model4.gen4_image,
-            ratio=ratio,
-            referenceImages=reference_images,
-        )
-
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=PATH_TEXT_TO_IMAGE,
-                method=HttpMethod.POST,
-                request_model=RunwayTextToImageRequest,
-                response_model=RunwayTextToImageResponse,
+        initial_response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
+            response_model=RunwayTextToImageResponse,
+            data=RunwayTextToImageRequest(
+                promptText=prompt,
+                model=Model4.gen4_image,
+                ratio=ratio,
+                referenceImages=reference_images,
            ),
-            request=request,
-            auth_kwargs=auth_kwargs,
        )
-        initial_response = await initial_operation.execute()

-        # Poll for completion
        final_response = await get_response(
+            cls,
            initial_response.id,
-            auth_kwargs=auth_kwargs,
-            node_id=cls.hidden.unique_id,
            estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
        )

        if not final_response.output:
@@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension):
            RunwayTextToImageNode,
        ]

async def comfy_entrypoint() -> RunwayExtension:
    return RunwayExtension()
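Note: taken together, the Runway hunks replace the old two-object pattern (a SynchronousOperation or PollingOperation wrapping an ApiEndpoint with explicit request/response models) with the sync_op and poll_op coroutines from comfy_api_nodes.util. A hedged sketch of the new calling convention, with keyword names taken from the hunks above rather than from any documentation:

    from comfy_api_nodes.util import ApiEndpoint, sync_op, poll_op

    # Submit the job; data is a pydantic request model, and the JSON
    # response is validated against response_model.
    initial = await sync_op(
        cls,
        endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
        response_model=RunwayImageToVideoResponse,
        data=request,
    )

    # Poll until status_extractor reports a terminal state; the progress
    # and duration hints drive the node's progress display.
    final = await poll_op(
        cls,
        ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{initial.id}"),
        response_model=TaskStatusResponse,
        status_extractor=lambda r: r.status.value,
        progress_extractor=extract_progress_from_task_status,
        estimated_duration=estimated_duration,
    )

GET is evidently the default method for ApiEndpoint, since the polling endpoints no longer pass one, and the completed/failed status lists that PollingOperation required appear to be handled inside poll_op now.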
View File
@@ -1,23 +1,20 @@
from typing import Optional
-from typing_extensions import override

import torch
from pydantic import BaseModel, Field
+from typing_extensions import override

-from comfy_api.latest import ComfyExtension, IO
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.util.validation_utils import get_number_of_images
-from comfy_api_nodes.apinode_utils import (
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.util import (
+    ApiEndpoint,
    download_url_to_video_output,
+    get_number_of_images,
+    poll_op,
+    sync_op,
    tensor_to_bytesio,
)

class Sora2GenerationRequest(BaseModel):
    prompt: str = Field(...)
    model: str = Field(...)
@@ -80,7 +77,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
                    control_after_generate=True,
                    optional=True,
                    tooltip="Seed to determine if node should re-run; "
                    "actual results are nondeterministic regardless of seed.",
                ),
            ],
            outputs=[
@@ -111,55 +108,34 @@ class OpenAIVideoSora2(IO.ComfyNode):
            if get_number_of_images(image) != 1:
                raise ValueError("Currently only one input image is supported.")
            files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        payload = Sora2GenerationRequest(
-            model=model,
-            prompt=prompt,
-            seconds=str(duration),
-            size=size,
-        )
-
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/openai/v1/videos",
-                method=HttpMethod.POST,
-                request_model=Sora2GenerationRequest,
-                response_model=Sora2GenerationResponse
+        initial_response = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
+            data=Sora2GenerationRequest(
+                model=model,
+                prompt=prompt,
+                seconds=str(duration),
+                size=size,
            ),
-            request=payload,
            files=files_input,
-            auth_kwargs=auth,
+            response_model=Sora2GenerationResponse,
            content_type="multipart/form-data",
        )
-        initial_response = await initial_operation.execute()
        if initial_response.error:
-            raise Exception(initial_response.error.message)
+            raise Exception(initial_response.error["message"])

        model_time_multiplier = 1 if model == "sora-2" else 2
-        poll_operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/openai/v1/videos/{initial_response.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=Sora2GenerationResponse
-            ),
-            completed_statuses=["completed"],
-            failed_statuses=["failed"],
+        await poll_op(
+            cls,
+            poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
+            response_model=Sora2GenerationResponse,
            status_extractor=lambda x: x.status,
-            auth_kwargs=auth,
            poll_interval=8.0,
            max_poll_attempts=160,
-            node_id=cls.hidden.unique_id,
-            estimated_duration=45 * (duration / 4) * model_time_multiplier,
+            estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
        )
-        await poll_operation.execute()
        return IO.NodeOutput(
-            await download_url_to_video_output(
-                f"/proxy/openai/v1/videos/{initial_response.id}/content",
-                auth_kwargs=auth,
-            )
+            await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
        )
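Note: the Sora 2 hunk shows how sync_op carries multipart uploads: the binary part goes in files while data still carries the pydantic model, and content_type="multipart/form-data" switches the encoding. A short sketch under those assumptions:

    files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
    initial_response = await sync_op(
        cls,
        endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
        data=Sora2GenerationRequest(model=model, prompt=prompt, seconds=str(duration), size=size),
        files=files_input,  # each field maps to (filename, file-like object, MIME type)
        response_model=Sora2GenerationResponse,
        content_type="multipart/form-data",
    )

The switch from initial_response.error.message to initial_response.error["message"] also implies the error field is now a plain dict on the new response model rather than a nested pydantic object.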
View File
@@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import (
    StabilityAudioInpaintRequest,
    StabilityAudioResponse,
)
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-)
-from comfy_api_nodes.apinode_utils import (
+from comfy_api_nodes.util import (
+    validate_audio_duration,
+    validate_string,
+    audio_input_to_mp3,
    bytesio_to_image_tensor,
    tensor_to_bytesio,
-    validate_string,
    audio_bytes_to_audio_input,
-    audio_input_to_mp3,
+    sync_op,
+    poll_op,
+    ApiEndpoint,
)
-from comfy_api_nodes.util.validation_utils import validate_audio_duration

import torch
import base64
@@ -161,19 +157,11 @@
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/generate/ultra",
-                method=HttpMethod.POST,
-                request_model=StabilityStableUltraRequest,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityStableUltraRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
+            response_model=StabilityStableUltraResponse,
+            data=StabilityStableUltraRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                aspect_ratio=aspect_ratio,
@@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/generate/sd3",
-                method=HttpMethod.POST,
-                request_model=StabilityStable3_5Request,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityStable3_5Request(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
+            response_model=StabilityStableUltraResponse,
+            data=StabilityStable3_5Request(
                prompt=prompt,
                negative_prompt=negative_prompt,
                aspect_ratio=aspect_ratio,
@@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/conservative",
-                method=HttpMethod.POST,
-                request_model=StabilityUpscaleConservativeRequest,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=StabilityUpscaleConservativeRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
+            response_model=StabilityStableUltraResponse,
+            data=StabilityUpscaleConservativeRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                creativity=round(creativity,2),
@@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/creative",
-                method=HttpMethod.POST,
-                request_model=StabilityUpscaleCreativeRequest,
-                response_model=StabilityAsyncResponse,
-            ),
-            request=StabilityUpscaleCreativeRequest(
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
+            response_model=StabilityAsyncResponse,
+            data=StabilityUpscaleCreativeRequest(
                prompt=prompt,
                negative_prompt=negative_prompt,
                creativity=round(creativity,2),
@@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
            ),
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

-        operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/stability/v2beta/results/{response_api.id}",
-                method=HttpMethod.GET,
-                request_model=EmptyRequest,
-                response_model=StabilityResultsGetResponse,
-            ),
+        response_poll = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
+            response_model=StabilityResultsGetResponse,
            poll_interval=3,
-            completed_statuses=[StabilityPollStatus.finished],
-            failed_statuses=[StabilityPollStatus.failed],
            status_extractor=lambda x: get_async_dummy_status(x),
-            auth_kwargs=auth,
-            node_id=cls.hidden.unique_id,
        )
-        response_poll: StabilityResultsGetResponse = await operation.execute()

        if response_poll.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
                "image": image_binary
            }

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/stable-image/upscale/fast",
-                method=HttpMethod.POST,
-                request_model=EmptyRequest,
-                response_model=StabilityStableUltraResponse,
-            ),
-            request=EmptyRequest(),
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
+            response_model=StabilityStableUltraResponse,
            files=files,
            content_type="multipart/form-data",
-            auth_kwargs=auth,
        )
-        response_api = await operation.execute()

        if response_api.finish_reason != "SUCCESS":
            raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
    async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
        validate_string(prompt, max_length=10000)
        payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
-                method=HttpMethod.POST,
-                request_model=StabilityTextToAudioRequest,
-                response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
+            response_model=StabilityAudioResponse,
+            data=payload,
            content_type="multipart/form-data",
-            auth_kwargs= {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
        )
-        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
        payload = StabilityAudioToAudioRequest(
            prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
        )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
-                method=HttpMethod.POST,
-                request_model=StabilityAudioToAudioRequest,
-                response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+        response_api = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
+            response_model=StabilityAudioResponse,
+            data=payload,
            content_type="multipart/form-data",
            files={"audio": audio_input_to_mp3(audio)},
-            auth_kwargs= {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
        )
-        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
            mask_start=mask_start,
            mask_end=mask_end,
        )
-        operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
-                method=HttpMethod.POST,
-                request_model=StabilityAudioInpaintRequest,
-                response_model=StabilityAudioResponse,
-            ),
-            request=payload,
+        response_api = await sync_op(
+            cls,
+            endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
+            response_model=StabilityAudioResponse,
+            data=payload,
            content_type="multipart/form-data",
            files={"audio": audio_input_to_mp3(audio)},
-            auth_kwargs={
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
        )
-        response_api = await operation.execute()
        if not response_api.audio:
            raise ValueError("No audio file was received in response.")
        return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
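Note: the creative upscale is the only asynchronous flow in this Stability file, so it keeps two phases after the migration: a sync_op POST that returns an id, then a poll_op against the shared results endpoint. A condensed sketch, with keyword names taken from the hunks above and payload/files standing in for the full argument lists:

    response_api = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
        response_model=StabilityAsyncResponse,
        data=payload,                        # StabilityUpscaleCreativeRequest instance
        files=files,                         # multipart image payload
        content_type="multipart/form-data",
    )
    response_poll = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
        response_model=StabilityResultsGetResponse,
        poll_interval=3,
        status_extractor=lambda x: get_async_dummy_status(x),
    )
    if response_poll.finish_reason != "SUCCESS":
        raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")

The get_async_dummy_status adapter survives the rewrite: it maps the async result object onto the status vocabulary poll_op expects, now that the explicit completed_statuses/failed_statuses lists are gone.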
View File
@@ -0,0 +1,418 @@
import builtins
from io import BytesIO
import aiohttp
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import topaz_api
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
get_fs_object_size,
get_number_of_images,
poll_op,
sync_op,
upload_images_to_comfyapi,
validate_container_format_is_mp4,
)
UPSCALER_MODELS_MAP = {
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
}
UPSCALER_VALUES_MAP = {
"FullHD (1080p)": 1920,
"4K (2160p)": 3840,
}
class TopazImageEnhance(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazImageEnhance",
display_name="Topaz Image Enhance",
category="api node/image/Topaz",
description="Industry-standard upscaling and image enhancement.",
inputs=[
IO.Combo.Input("model", options=["Reimagine"]),
IO.Image.Input("image"),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional text prompt for creative upscaling guidance.",
optional=True,
),
IO.Combo.Input(
"subject_detection",
options=["All", "Foreground", "Background"],
optional=True,
),
IO.Boolean.Input(
"face_enhancement",
default=True,
optional=True,
tooltip="Enhance faces (if present) during processing.",
),
IO.Float.Input(
"face_enhancement_creativity",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Set the creativity level for face enhancement.",
),
IO.Float.Input(
"face_enhancement_strength",
default=1.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Controls how sharp enhanced faces are relative to the background.",
),
IO.Boolean.Input(
"crop_to_fill",
default=False,
optional=True,
tooltip="By default, the image is letterboxed when the output aspect ratio differs. "
"Enable to crop the image to fill the output dimensions.",
),
IO.Int.Input(
"output_width",
default=0,
min=0,
max=32000,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).",
),
IO.Int.Input(
"output_height",
default=0,
min=0,
max=32000,
step=1,
display_mode=IO.NumberDisplay.number,
optional=True,
tooltip="Zero value means to output in the same height as original or output width.",
),
IO.Int.Input(
"creativity",
default=3,
min=1,
max=9,
step=1,
display_mode=IO.NumberDisplay.slider,
optional=True,
),
IO.Boolean.Input(
"face_preservation",
default=True,
optional=True,
tooltip="Preserve subjects' facial identity.",
),
IO.Boolean.Input(
"color_preservation",
default=True,
optional=True,
tooltip="Preserve the original colors.",
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model: str,
image: torch.Tensor,
prompt: str = "",
subject_detection: str = "All",
face_enhancement: bool = True,
face_enhancement_creativity: float = 1.0,
face_enhancement_strength: float = 0.8,
crop_to_fill: bool = False,
output_width: int = 0,
output_height: int = 0,
creativity: int = 3,
face_preservation: bool = True,
color_preservation: bool = True,
) -> IO.NodeOutput:
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
response_model=topaz_api.ImageAsyncTaskResponse,
data=topaz_api.ImageEnhanceRequest(
model=model,
prompt=prompt,
subject_detection=subject_detection,
face_enhancement=face_enhancement,
face_enhancement_creativity=face_enhancement_creativity,
face_enhancement_strength=face_enhancement_strength,
crop_to_fill=crop_to_fill,
output_width=output_width if output_width else None,
output_height=output_height if output_height else None,
creativity=creativity,
face_preservation=str(face_preservation).lower(),
color_preservation=str(color_preservation).lower(),
source_url=download_url[0],
output_format="png",
),
content_type="multipart/form-data",
)
await poll_op(
cls,
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
response_model=topaz_api.ImageStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: x.credits * 0.08,
poll_interval=8.0,
max_poll_attempts=160,
estimated_duration=60,
)
results = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
response_model=topaz_api.ImageDownloadResponse,
monitor_progress=False,
)
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
class TopazVideoEnhance(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
IO.Combo.Input(
"upscaler_creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creativity level (applies only to Starlight (Astra) Creative).",
optional=True,
),
IO.Boolean.Input("interpolation_enabled", default=False, optional=True),
IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
optional=True,
),
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
optional=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
optional=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
optional=True,
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_enabled: bool,
upscaler_model: str,
upscaler_resolution: str,
upscaler_creativity: str = "low",
interpolation_enabled: bool = False,
interpolation_model: str = "apo-8",
interpolation_slowmo: int = 1,
interpolation_frame_rate: int = 60,
interpolation_duplicate: bool = False,
interpolation_duplicate_threshold: float = 0.01,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
if upscaler_enabled is False and interpolation_enabled is False:
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_enabled:
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
filters.append(
topaz_api.VideoEnhancementFilter(
model=UPSCALER_MODELS_MAP[upscaler_model],
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
),
)
if interpolation_enabled:
target_frame_rate = interpolation_frame_rate
filters.append(
topaz_api.VideoFrameInterpolationFilter(
model=interpolation_model,
slowmo=interpolation_slowmo,
fps=interpolation_frame_rate,
duplicate=interpolation_duplicate,
duplicate_threshold=interpolation_duplicate_threshold,
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=topaz_api.CreateVideoResponse,
data=topaz_api.CreateVideoRequest(
source=topaz_api.CreateCreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=topaz_api.Resolution(width=src_width, height=src_height),
),
filters=filters,
output=topaz_api.OutputInformationVideo(
resolution=topaz_api.Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=topaz_api.VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=topaz_api.VideoCompleteUploadResponse,
data=topaz_api.VideoCompleteUploadRequest(
uploadResults=[
topaz_api.VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=topaz_api.VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
max_poll_attempts=320,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TopazImageEnhance,
TopazVideoEnhance,
]
async def comfy_entrypoint() -> TopazExtension:
return TopazExtension()
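Note: TopazVideoEnhance introduces the only direct (non-proxied) HTTP call in this file set: the source video is PUT to a presigned URL returned by the accept step, and the ETag response header is echoed back in the complete-upload call. A self-contained sketch of that handshake using plain aiohttp (the URL here is a hypothetical presigned PUT target):

    import aiohttp

    async def upload_part(presigned_url: str, payload: bytes) -> str:
        # Upload one part and return its ETag for the complete-upload call.
        async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
            # raise_for_status=True turns any non-2xx reply into an exception.
            async with session.put(presigned_url, data=payload, raise_for_status=True) as res:
                # aiohttp response headers are case-insensitive, so "Etag" matches "ETag".
                return res.headers["Etag"]

The len(upload_res.urls) > 1 guard in the node makes the single-part assumption explicit; a true multipart upload would repeat this per URL with partNum incremented.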
File diff suppressed because it is too large

View File
@@ -1,28 +1,21 @@
-import logging
import base64
-import aiohttp
-import torch
from io import BytesIO
-from typing import Optional
from typing_extensions import override

-from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoFromFile
-from comfy_api_nodes.apis import (
-    VeoGenVidRequest,
-    VeoGenVidResponse,
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.apis.veo_api import (
    VeoGenVidPollRequest,
    VeoGenVidPollResponse,
+    VeoGenVidRequest,
+    VeoGenVidResponse,
)
-from comfy_api_nodes.apis.client import (
+from comfy_api_nodes.util import (
    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-)
-from comfy_api_nodes.apinode_utils import (
-    downscale_image_tensor,
+    download_url_to_video_output,
+    poll_op,
+    sync_op,
    tensor_to_base64_string,
)
@@ -35,28 +28,6 @@ MODELS_MAP = {
    "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
}

-def convert_image_to_base64(image: torch.Tensor):
-    if image is None:
-        return None
-    scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
-    return tensor_to_base64_string(scaled_image)
-
-def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
-    if (
-        poll_response.response
-        and hasattr(poll_response.response, "videos")
-        and poll_response.response.videos
-        and len(poll_response.response.videos) > 0
-    ):
-        video = poll_response.response.videos[0]
-    else:
-        return None
-    if hasattr(video, "gcsUri") and video.gcsUri:
-        return str(video.gcsUri)
-    return None

class VeoVideoGenerationNode(IO.ComfyNode):
    """
@@ -169,18 +140,13 @@ class VeoVideoGenerationNode(IO.ComfyNode):
        # Prepare the instances for the request
        instances = []

-        instance = {
-            "prompt": prompt
-        }
+        instance = {"prompt": prompt}

        # Add image if provided
        if image is not None:
-            image_base64 = convert_image_to_base64(image)
+            image_base64 = tensor_to_base64_string(image)
            if image_base64:
-                instance["image"] = {
-                    "bytesBase64Encoded": image_base64,
-                    "mimeType": "image/png"
-                }
+                instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}

        instances.append(instance)
@@ -198,119 +164,77 @@ class VeoVideoGenerationNode(IO.ComfyNode):
        if seed > 0:
            parameters["seed"] = seed

        # Only add generateAudio for Veo 3 models
-        if "veo-3.0" in model:
+        if model.find("veo-2.0") == -1:
            parameters["generateAudio"] = generate_audio

-        auth = {
-            "auth_token": cls.hidden.auth_token_comfy_org,
-            "comfy_api_key": cls.hidden.api_key_comfy_org,
-        }
-
-        # Initial request to start video generation
-        initial_operation = SynchronousOperation(
-            endpoint=ApiEndpoint(
-                path=f"/proxy/veo/{model}/generate",
-                method=HttpMethod.POST,
-                request_model=VeoGenVidRequest,
-                response_model=VeoGenVidResponse
-            ),
-            request=VeoGenVidRequest(
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
+            response_model=VeoGenVidResponse,
+            data=VeoGenVidRequest(
                instances=instances,
-                parameters=parameters
+                parameters=parameters,
            ),
-            auth_kwargs=auth,
        )
-        initial_response = await initial_operation.execute()
-        operation_name = initial_response.name
-        logging.info("Veo generation started with operation name: %s", operation_name)

-        # Define status extractor function
        def status_extractor(response):
            # Only return "completed" if the operation is done, regardless of success or failure
            # We'll check for errors after polling completes
            return "completed" if response.done else "pending"

-        # Define progress extractor function
-        def progress_extractor(response):
-            # Could be enhanced if the API provides progress information
-            return None
-
-        # Define the polling operation
-        poll_operation = PollingOperation(
-            poll_endpoint=ApiEndpoint(
-                path=f"/proxy/veo/{model}/poll",
-                method=HttpMethod.POST,
-                request_model=VeoGenVidPollRequest,
-                response_model=VeoGenVidPollResponse
-            ),
-            completed_statuses=["completed"],
-            failed_statuses=[],  # No failed statuses, we'll handle errors after polling
+        poll_response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
+            response_model=VeoGenVidPollResponse,
            status_extractor=status_extractor,
-            progress_extractor=progress_extractor,
-            request=VeoGenVidPollRequest(
-                operationName=operation_name
+            data=VeoGenVidPollRequest(
+                operationName=initial_response.name,
            ),
-            auth_kwargs=auth,
            poll_interval=5.0,
-            result_url_extractor=get_video_url_from_response,
-            node_id=cls.hidden.unique_id,
            estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
        )

-        # Execute the polling operation
-        poll_response = await poll_operation.execute()
-
        # Now check for errors in the final response
        # Check for error in poll response
-        if hasattr(poll_response, 'error') and poll_response.error:
-            error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
-            logging.error(error_message)
-            raise Exception(error_message)
+        if poll_response.error:
+            raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")

        # Check for RAI filtered content
-        if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
-                poll_response.response.raiMediaFilteredCount > 0):
+        if (
+            hasattr(poll_response.response, "raiMediaFilteredCount")
+            and poll_response.response.raiMediaFilteredCount > 0
+        ):
            # Extract reason message if available
-            if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
-                    poll_response.response.raiMediaFilteredReasons):
+            if (
+                hasattr(poll_response.response, "raiMediaFilteredReasons")
+                and poll_response.response.raiMediaFilteredReasons
+            ):
                reason = poll_response.response.raiMediaFilteredReasons[0]
                error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
            else:
                error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
-            logging.error(error_message)
            raise Exception(error_message)

        # Extract video data
-        if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
+        if (
+            poll_response.response
+            and hasattr(poll_response.response, "videos")
+            and poll_response.response.videos
+            and len(poll_response.response.videos) > 0
+        ):
            video = poll_response.response.videos[0]

            # Check if video is provided as base64 or URL
-            if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
-                # Decode base64 string to bytes
-                video_data = base64.b64decode(video.bytesBase64Encoded)
-            elif hasattr(video, 'gcsUri') and video.gcsUri:
-                # Download from URL
-                async with aiohttp.ClientSession() as session:
-                    async with session.get(video.gcsUri) as video_response:
-                        video_data = await video_response.content.read()
-            else:
-                raise Exception("Video returned but no data or URL was provided")
-        else:
-            raise Exception("Video generation completed but no video was returned")
+            if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
+                return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
+            if hasattr(video, "gcsUri") and video.gcsUri:
+                return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
+            raise Exception("Video returned but no data or URL was provided")
+        raise Exception("Video generation completed but no video was returned")

-        if not video_data:
-            raise Exception("No video data was returned")
-
-        logging.info("Video generation completed successfully")
-
-        # Convert video data to BytesIO object
-        video_io = BytesIO(video_data)
-
-        # Return VideoFromFile object
-        return IO.NodeOutput(VideoFromFile(video_io))

class Veo3VideoGenerationNode(VeoVideoGenerationNode):
@@ -394,7 +318,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
                IO.Combo.Input(
                    "model",
                    options=[
-                        "veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001"
+                        "veo-3.1-generate",
+                        "veo-3.1-fast-generate",
+                        "veo-3.0-generate-001",
+                        "veo-3.0-fast-generate-001",
                    ],
                    default="veo-3.0-generate-001",
                    tooltip="Veo 3 model to use for video generation",
@@ -427,5 +354,6 @@ class VeoExtension(ComfyExtension):
            Veo3VideoGenerationNode,
        ]

async def comfy_entrypoint() -> VeoExtension:
    return VeoExtension()
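Note: the rewritten Veo node keeps its unusual polling contract: status_extractor maps a done operation to "completed" whether it succeeded or failed, and error/RAI checks run on the returned response afterwards. A condensed sketch of that pattern, assuming poll_op's signature as used above:

    def status_extractor(response):
        # Terminal as soon as the long-running operation is done;
        # success vs. failure is decided after polling returns.
        return "completed" if response.done else "pending"

    poll_response = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
        response_model=VeoGenVidPollResponse,
        status_extractor=status_extractor,
        data=VeoGenVidPollRequest(operationName=initial_response.name),
        poll_interval=5.0,
    )
    if poll_response.error:
        raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")

Also worth noting: the audio condition changed from "model contains veo-3.0" to "model does not contain veo-2.0", which is what lets the new veo-3.1 options pass generateAudio as well.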
View File
@@ -1,27 +1,23 @@
import logging
from enum import Enum
-from typing import Any, Callable, Optional, Literal, TypeVar
-from typing_extensions import override
+from typing import Literal, Optional, TypeVar

import torch
from pydantic import BaseModel, Field
+from typing_extensions import override

-from comfy_api.latest import ComfyExtension, IO
-from comfy_api_nodes.util.validation_utils import (
-    validate_aspect_ratio_closeness,
-    validate_image_dimensions,
-    validate_image_aspect_ratio_range,
-    get_number_of_images,
-)
-from comfy_api_nodes.apis.client import (
+from comfy_api.latest import IO, ComfyExtension
+from comfy_api_nodes.util import (
    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
+    download_url_to_video_output,
+    get_number_of_images,
+    poll_op,
+    sync_op,
+    upload_images_to_comfyapi,
+    validate_image_aspect_ratio,
+    validate_image_dimensions,
+    validate_images_aspect_ratio_closeness,
)
-from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi

VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
@@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"

R = TypeVar("R")

class VideoModelName(str, Enum):
-    vidu_q1 = 'viduq1'
+    vidu_q1 = "viduq1"

class AspectRatio(str, Enum):
@@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel):
    images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")

-class TaskStatus(str, Enum):
-    created = "created"
-    queueing = "queueing"
-    processing = "processing"
-    success = "success"
-    failed = "failed"

class TaskCreationResponse(BaseModel):
    task_id: str = Field(...)
-    state: TaskStatus = Field(...)
+    state: str = Field(...)
    created_at: str = Field(...)
    code: Optional[int] = Field(None, description="Error code")
@@ -85,32 +74,11 @@ class TaskResult(BaseModel):

class TaskStatusResponse(BaseModel):
-    state: TaskStatus = Field(...)
+    state: str = Field(...)
    err_code: Optional[str] = Field(None)
    creations: list[TaskResult] = Field(..., description="Generated results")

-async def poll_until_finished(
-    auth_kwargs: dict[str, str],
-    api_endpoint: ApiEndpoint[Any, R],
-    result_url_extractor: Optional[Callable[[R], str]] = None,
-    estimated_duration: Optional[int] = None,
-    node_id: Optional[str] = None,
-) -> R:
-    return await PollingOperation(
-        poll_endpoint=api_endpoint,
-        completed_statuses=[TaskStatus.success.value],
-        failed_statuses=[TaskStatus.failed.value],
-        status_extractor=lambda response: response.state.value,
-        auth_kwargs=auth_kwargs,
-        result_url_extractor=result_url_extractor,
-        estimated_duration=estimated_duration,
-        node_id=node_id,
-        poll_interval=16.0,
-        max_poll_attempts=256,
-    ).execute()

def get_video_url_from_response(response) -> Optional[str]:
    if response.creations:
        return response.creations[0].url
@@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult:

async def execute_task(
+    cls: type[IO.ComfyNode],
    vidu_endpoint: str,
-    auth_kwargs: Optional[dict[str, str]],
    payload: TaskCreationRequest,
    estimated_duration: int,
-    node_id: str,
) -> R:
-    response = await SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path=vidu_endpoint,
-            method=HttpMethod.POST,
-            request_model=TaskCreationRequest,
-            response_model=TaskCreationResponse,
-        ),
-        request=payload,
-        auth_kwargs=auth_kwargs,
-    ).execute()
-    if response.state == TaskStatus.failed:
+    response = await sync_op(
+        cls,
+        endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
+        response_model=TaskCreationResponse,
+        data=payload,
+    )
+    if response.state == "failed":
        error_msg = f"Vidu request failed. Code: {response.code}"
        logging.error(error_msg)
        raise RuntimeError(error_msg)
-    return await poll_until_finished(
-        auth_kwargs,
-        ApiEndpoint(
-            path=VIDU_GET_GENERATION_STATUS % response.task_id,
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=TaskStatusResponse,
-        ),
-        result_url_extractor=get_video_url_from_response,
+    return await poll_op(
+        cls,
+        ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
+        response_model=TaskStatusResponse,
+        status_extractor=lambda r: r.state,
        estimated_duration=estimated_duration,
-        node_id=node_id,
    )
@ -258,11 +216,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
resolution=resolution, resolution=resolution,
movement_amplitude=movement_amplitude, movement_amplitude=movement_amplitude,
) )
auth = { results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@ -353,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if get_number_of_images(image) > 1: if get_number_of_images(image) > 1:
raise ValueError("Only one input image is allowed.") raise ValueError("Only one input image is allowed.")
validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
prompt=prompt, prompt=prompt,
@ -362,17 +316,13 @@ class ViduImageToVideoNode(IO.ComfyNode):
resolution=resolution, resolution=resolution,
movement_amplitude=movement_amplitude, movement_amplitude=movement_amplitude,
) )
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi( payload.images = await upload_images_to_comfyapi(
cls,
image, image,
max_images=1, max_images=1,
mime_type="image/png", mime_type="image/png",
auth_kwargs=auth,
) )
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@ -473,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
if a > 7: if a > 7:
raise ValueError("Too many images, maximum allowed is 7.") raise ValueError("Too many images, maximum allowed is 7.")
for image in images: for image in images:
validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) validate_image_aspect_ratio(image, (1, 4), (4, 1))
validate_image_dimensions(image, min_width=128, min_height=128) validate_image_dimensions(image, min_width=128, min_height=128)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
@ -484,17 +434,13 @@ class ViduReferenceVideoNode(IO.ComfyNode):
resolution=resolution, resolution=resolution,
movement_amplitude=movement_amplitude, movement_amplitude=movement_amplitude,
) )
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = await upload_images_to_comfyapi( payload.images = await upload_images_to_comfyapi(
cls,
images, images,
max_images=7, max_images=7,
mime_type="image/png", mime_type="image/png",
auth_kwargs=auth,
) )
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@ -587,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution: str, resolution: str,
movement_amplitude: str, movement_amplitude: str,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest( payload = TaskCreationRequest(
model_name=model, model_name=model,
prompt=prompt, prompt=prompt,
@ -596,15 +542,11 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
resolution=resolution, resolution=resolution,
movement_amplitude=movement_amplitude, movement_amplitude=movement_amplitude,
) )
auth = {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
}
payload.images = [ payload.images = [
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame) for frame in (first_frame, end_frame)
] ]
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension):
ViduStartEndToVideoNode, ViduStartEndToVideoNode,
] ]
async def comfy_entrypoint() -> ViduExtension: async def comfy_entrypoint() -> ViduExtension:
return ViduExtension() return ViduExtension()

View File

@@ -1,28 +1,24 @@
import re
-from typing import Optional, Type, Union
+from typing import Optional
-from typing_extensions import override
import torch
from pydantic import BaseModel, Field
-from comfy_api.latest import ComfyExtension, Input, IO
-from comfy_api_nodes.apis.client import (
-    ApiEndpoint,
-    HttpMethod,
-    SynchronousOperation,
-    PollingOperation,
-    EmptyRequest,
-    R,
-    T,
-)
-from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
-from comfy_api_nodes.apinode_utils import (
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api_nodes.util import (
+    ApiEndpoint,
+    audio_to_base64_string,
    download_url_to_image_tensor,
    download_url_to_video_output,
+    get_number_of_images,
+    poll_op,
+    sync_op,
    tensor_to_base64_string,
-    audio_to_base64_string,
+    validate_audio_duration,
)
class Text2ImageInputField(BaseModel):
    prompt: str = Field(...)
    negative_prompt: Optional[str] = Field(None)
@@ -146,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel):
    request_id: str = Field(...)
-RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
+RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
-async def process_task(
-    auth_kwargs: dict[str, str],
-    url: str,
-    request_model: Type[T],
-    response_model: Type[R],
-    payload: Union[
-        Text2ImageTaskCreationRequest,
-        Image2ImageTaskCreationRequest,
-        Text2VideoTaskCreationRequest,
-        Image2VideoTaskCreationRequest,
-    ],
-    node_id: str,
-    estimated_duration: int,
-    poll_interval: int,
-) -> Type[R]:
-    initial_response = await SynchronousOperation(
-        endpoint=ApiEndpoint(
-            path=url,
-            method=HttpMethod.POST,
-            request_model=request_model,
-            response_model=TaskCreationResponse,
-        ),
-        request=payload,
-        auth_kwargs=auth_kwargs,
-    ).execute()
-    if not initial_response.output:
-        raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
-    return await PollingOperation(
-        poll_endpoint=ApiEndpoint(
-            path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
-            method=HttpMethod.GET,
-            request_model=EmptyRequest,
-            response_model=response_model,
-        ),
-        completed_statuses=["SUCCEEDED"],
-        failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
-        status_extractor=lambda x: x.output.task_status,
-        estimated_duration=estimated_duration,
-        poll_interval=poll_interval,
-        node_id=node_id,
-        auth_kwargs=auth_kwargs,
-    ).execute()
class WanTextToImageApi(IO.ComfyNode):
@@ -259,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode):
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the result.",
+                    tooltip='Whether to add an "AI generated" watermark to the result.',
                    optional=True,
                ),
            ],
@@ -286,26 +236,28 @@ class WanTextToImageApi(IO.ComfyNode):
        prompt_extend: bool = True,
        watermark: bool = True,
    ):
-        payload = Text2ImageTaskCreationRequest(
-            model=model,
-            input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
-            parameters=Txt2ImageParametersField(
-                size=f"{width}*{height}",
-                seed=seed,
-                prompt_extend=prompt_extend,
-                watermark=watermark,
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
+            response_model=TaskCreationResponse,
+            data=Text2ImageTaskCreationRequest(
+                model=model,
+                input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
+                parameters=Txt2ImageParametersField(
+                    size=f"{width}*{height}",
+                    seed=seed,
+                    prompt_extend=prompt_extend,
+                    watermark=watermark,
+                ),
            ),
        )
-        response = await process_task(
-            {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
-            request_model=Text2ImageTaskCreationRequest,
+        if not initial_response.output:
+            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
            response_model=ImageTaskStatusResponse,
-            payload=payload,
-            node_id=cls.hidden.unique_id,
+            status_extractor=lambda x: x.output.task_status,
            estimated_duration=9,
            poll_interval=3,
        )
@@ -320,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode):
            display_name="Wan Image to Image",
            category="api node/image/Wan",
            description="Generates an image from one or two input images and a text prompt. "
            "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
            inputs=[
                IO.Combo.Input(
                    "model",
@@ -376,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode):
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the result.",
+                    tooltip='Whether to add an "AI generated" watermark to the result.',
                    optional=True,
                ),
            ],
@@ -408,28 +360,30 @@ class WanImageToImageApi(IO.ComfyNode):
            raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
        images = []
        for i in image:
-            images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
+            images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
-        payload = Image2ImageTaskCreationRequest(
-            model=model,
-            input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
-            parameters=Image2ImageParametersField(
-                # size=f"{width}*{height}",
-                seed=seed,
-                watermark=watermark,
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"),
+            response_model=TaskCreationResponse,
+            data=Image2ImageTaskCreationRequest(
+                model=model,
+                input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
+                parameters=Image2ImageParametersField(
+                    # size=f"{width}*{height}",
+                    seed=seed,
+                    watermark=watermark,
+                ),
            ),
        )
-        response = await process_task(
-            {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
-            request_model=Image2ImageTaskCreationRequest,
+        if not initial_response.output:
+            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
            response_model=ImageTaskStatusResponse,
-            payload=payload,
-            node_id=cls.hidden.unique_id,
+            status_extractor=lambda x: x.output.task_status,
            estimated_duration=42,
-            poll_interval=3,
+            poll_interval=4,
        )
        return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
@@ -523,7 +477,7 @@ class WanTextToVideoApi(IO.ComfyNode):
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the result.",
+                    tooltip='Whether to add an "AI generated" watermark to the result.',
                    optional=True,
                ),
            ],
@@ -557,28 +511,31 @@ class WanTextToVideoApi(IO.ComfyNode):
        if audio is not None:
            validate_audio_duration(audio, 3.0, 29.0)
            audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
-        payload = Text2VideoTaskCreationRequest(
-            model=model,
-            input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
-            parameters=Text2VideoParametersField(
-                size=f"{width}*{height}",
-                duration=duration,
-                seed=seed,
-                audio=generate_audio,
-                prompt_extend=prompt_extend,
-                watermark=watermark,
+
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
+            response_model=TaskCreationResponse,
+            data=Text2VideoTaskCreationRequest(
+                model=model,
+                input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
+                parameters=Text2VideoParametersField(
+                    size=f"{width}*{height}",
+                    duration=duration,
+                    seed=seed,
+                    audio=generate_audio,
+                    prompt_extend=prompt_extend,
+                    watermark=watermark,
+                ),
            ),
        )
-        response = await process_task(
-            {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
-            request_model=Text2VideoTaskCreationRequest,
+        if not initial_response.output:
+            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
            response_model=VideoTaskStatusResponse,
-            payload=payload,
-            node_id=cls.hidden.unique_id,
+            status_extractor=lambda x: x.output.task_status,
            estimated_duration=120 * int(duration / 5),
            poll_interval=6,
        )
@@ -667,7 +624,7 @@ class WanImageToVideoApi(IO.ComfyNode):
                IO.Boolean.Input(
                    "watermark",
                    default=True,
-                    tooltip="Whether to add an \"AI generated\" watermark to the result.",
+                    tooltip='Whether to add an "AI generated" watermark to the result.',
                    optional=True,
                ),
            ],
@@ -699,35 +656,37 @@ class WanImageToVideoApi(IO.ComfyNode):
    ):
        if get_number_of_images(image) != 1:
            raise ValueError("Exactly one input image is required.")
-        image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
+        image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
        audio_url = None
        if audio is not None:
            validate_audio_duration(audio, 3.0, 29.0)
            audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
-        payload = Image2VideoTaskCreationRequest(
-            model=model,
-            input=Image2VideoInputField(
-                prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
-            ),
-            parameters=Image2VideoParametersField(
-                resolution=resolution,
-                duration=duration,
-                seed=seed,
-                audio=generate_audio,
-                prompt_extend=prompt_extend,
-                watermark=watermark,
+        initial_response = await sync_op(
+            cls,
+            ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
+            response_model=TaskCreationResponse,
+            data=Image2VideoTaskCreationRequest(
+                model=model,
+                input=Image2VideoInputField(
+                    prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
+                ),
+                parameters=Image2VideoParametersField(
+                    resolution=resolution,
+                    duration=duration,
+                    seed=seed,
+                    audio=generate_audio,
+                    prompt_extend=prompt_extend,
+                    watermark=watermark,
+                ),
            ),
        )
-        response = await process_task(
-            {
-                "auth_token": cls.hidden.auth_token_comfy_org,
-                "comfy_api_key": cls.hidden.api_key_comfy_org,
-            },
-            "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
-            request_model=Image2VideoTaskCreationRequest,
+        if not initial_response.output:
+            raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
+        response = await poll_op(
+            cls,
+            ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
            response_model=VideoTaskStatusResponse,
-            payload=payload,
-            node_id=cls.hidden.unique_id,
+            status_extractor=lambda x: x.output.task_status,
            estimated_duration=120 * int(duration / 5),
            poll_interval=6,
        )

View File

@@ -0,0 +1,99 @@
from ._helpers import get_fs_object_size
from .client import (
ApiEndpoint,
poll_op,
poll_op_raw,
sync_op,
sync_op_raw,
)
from .conversions import (
audio_bytes_to_audio_input,
audio_input_to_mp3,
audio_to_base64_string,
bytesio_to_image_tensor,
downscale_image_tensor,
image_tensor_pair_to_batch,
pil_to_bytesio,
resize_mask_to_image,
tensor_to_base64_string,
tensor_to_bytesio,
tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video,
video_to_base64_string,
)
from .download_helpers import (
download_url_as_bytesio,
download_url_to_bytesio,
download_url_to_image_tensor,
download_url_to_video_output,
)
from .upload_helpers import (
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
)
from .validation_utils import (
get_image_dimensions,
get_number_of_images,
validate_aspect_ratio_string,
validate_audio_duration,
validate_container_format_is_mp4,
validate_image_aspect_ratio,
validate_image_dimensions,
validate_images_aspect_ratio_closeness,
validate_string,
validate_video_dimensions,
validate_video_duration,
)
__all__ = [
# API client
"ApiEndpoint",
"poll_op",
"poll_op_raw",
"sync_op",
"sync_op_raw",
# Upload helpers
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_images_to_comfyapi",
"upload_video_to_comfyapi",
# Download helpers
"download_url_as_bytesio",
"download_url_to_bytesio",
"download_url_to_image_tensor",
"download_url_to_video_output",
# Conversions
"audio_bytes_to_audio_input",
"audio_input_to_mp3",
"audio_to_base64_string",
"bytesio_to_image_tensor",
"downscale_image_tensor",
"image_tensor_pair_to_batch",
"pil_to_bytesio",
"resize_mask_to_image",
"tensor_to_base64_string",
"tensor_to_bytesio",
"tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video",
"video_to_base64_string",
# Validation utilities
"get_image_dimensions",
"get_number_of_images",
"validate_aspect_ratio_string",
"validate_audio_duration",
"validate_container_format_is_mp4",
"validate_image_aspect_ratio",
"validate_image_dimensions",
"validate_images_aspect_ratio_closeness",
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
# Misc functions
"get_fs_object_size",
]
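# Illustrative import sketch: API node modules consume this package through the façade
# rather than the private submodules (as the Vidu and Wan nodes above do), e.g.
#
#     from comfy_api_nodes.util import ApiEndpoint, poll_op, sync_op
#
# so call sites stay stable if the internal module layout changes again.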

View File

@@ -0,0 +1,71 @@
import asyncio
import contextlib
import os
import time
from io import BytesIO
from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
from comfy_api.latest import IO
from .common_exceptions import ProcessingInterrupted
def is_processing_interrupted() -> bool:
"""Return True if user/runtime requested interruption."""
return processing_interrupted()
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
return node_cls.hidden.unique_id
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
if node_cls.hidden.auth_token_comfy_org:
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
if node_cls.hidden.api_key_comfy_org:
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
return {}
def default_base_url() -> str:
return getattr(args, "comfy_api_base", "https://api.comfy.org")
async def sleep_with_interrupt(
seconds: float,
node_cls: Optional[type[IO.ComfyNode]],
label: Optional[str] = None,
start_ts: Optional[float] = None,
estimated_total: Optional[int] = None,
*,
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
):
"""
Sleep in 1s slices while:
- Checking for interruption (raises ProcessingInterrupted).
- Optionally emitting time progress via display_callback (if provided).
"""
end = time.monotonic() + seconds
while True:
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
now = time.monotonic()
if start_ts is not None and label and display_callback:
with contextlib.suppress(Exception):
display_callback(node_cls, label, int(now - start_ts), estimated_total)
if now >= end:
break
await asyncio.sleep(min(1.0, end - now))
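# Minimal usage sketch (assumed caller-side values): wait ~5s between polls while
# remaining cancellable; raises ProcessingInterrupted as soon as the user interrupts.
#
#     await sleep_with_interrupt(5.0, node_cls, "Waiting", time.monotonic(), 60,
#                                display_callback=my_display_callback)
#
# `node_cls` and `my_display_callback` are hypothetical caller-supplied objects
# matching the signature above.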
def mimetype_to_extension(mime_type: str) -> str:
"""Converts a MIME type to a file extension."""
return mime_type.split("/")[-1].lower()
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
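# e.g. get_fs_object_size("/tmp/frame.png") returns the on-disk size in bytes, while
# get_fs_object_size(BytesIO(b"abc")) returns 3 (the length of the in-memory buffer).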

View File

@@ -0,0 +1,946 @@
import asyncio
import contextlib
import json
import logging
import time
import uuid
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
from urllib.parse import urljoin, urlparse
import aiohttp
from aiohttp.client_exceptions import ClientError, ContentTypeError
from pydantic import BaseModel
from comfy import utils
from comfy_api.latest import IO
from server import PromptServer
from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
get_node_id,
is_processing_interrupted,
sleep_with_interrupt,
)
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
M = TypeVar("M", bound=BaseModel)
class ApiEndpoint:
def __init__(
self,
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
query_params: Optional[dict[str, Any]] = None,
headers: Optional[dict[str, str]] = None,
):
self.path = path
self.method = method
self.query_params = query_params or {}
self.headers = headers or {}
@dataclass
class _RequestConfig:
node_cls: type[IO.ComfyNode]
endpoint: ApiEndpoint
timeout: float
content_type: str
data: Optional[dict[str, Any]]
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
multipart_parser: Optional[Callable]
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
estimated_total: Optional[int] = None
final_label_on_success: Optional[str] = "Completed"
progress_origin_ts: Optional[float] = None
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
@dataclass
class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
price: Optional[float] = None
estimated_duration: Optional[int] = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: Optional[float] = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
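# Status matching is case-insensitive: values pass through _normalize_status_value()
# before comparison, so an upstream "SUCCEEDED" (as the Wan API reports) matches the
# "succeeded" entry in COMPLETED_STATUSES.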
async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
response_model: Type[M],
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
data: Optional[BaseModel] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
cls,
endpoint,
price_extractor=_wrap_model_extractor(response_model, price_extractor),
data=data,
files=files,
content_type=content_type,
timeout=timeout,
multipart_parser=multipart_parser,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff=retry_backoff,
wait_label=wait_label,
estimated_duration=estimated_duration,
as_binary=False,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
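# Minimal call sketch (the endpoint path and models here are hypothetical):
#
#     created = await sync_op(
#         cls,
#         ApiEndpoint(path="/proxy/example/generate", method="POST"),
#         response_model=MyCreationResponse,
#         data=MyCreationRequest(prompt=prompt),
#     )
#
# A BaseModel `data` is serialized via model_dump(exclude_none=True) before sending,
# and the JSON reply is validated into `response_model` (validation failures raise).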
async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
response_model: Type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[BaseModel] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
cls,
poll_endpoint=poll_endpoint,
status_extractor=_wrap_model_extractor(response_model, status_extractor),
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
price_extractor=_wrap_model_extractor(response_model, price_extractor),
completed_statuses=completed_statuses,
failed_statuses=failed_statuses,
queued_statuses=queued_statuses,
data=data,
poll_interval=poll_interval,
max_poll_attempts=max_poll_attempts,
timeout_per_poll=timeout_per_poll,
max_retries_per_poll=max_retries_per_poll,
retry_delay_per_poll=retry_delay_per_poll,
retry_backoff_per_poll=retry_backoff_per_poll,
estimated_duration=estimated_duration,
cancel_endpoint=cancel_endpoint,
cancel_timeout=cancel_timeout,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
return _validate_or_raise(response_model, raw)
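# Minimal polling sketch (path and model are hypothetical):
#
#     result = await poll_op(
#         cls,
#         ApiEndpoint(path=f"/proxy/example/tasks/{created.task_id}"),
#         response_model=MyStatusResponse,
#         status_extractor=lambda r: r.status,
#         estimated_duration=60,
#     )
#
# With the default status sets, "success"-like states return the final response,
# "failed"-like states raise, and queued states do not consume max_poll_attempts.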
async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
content_type: str = "application/json",
timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None,
as_binary: bool = False,
final_label_on_success: Optional[str] = "Completed",
progress_origin_ts: Optional[float] = None,
monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
for k, v in list(data.items()):
if isinstance(v, Enum):
data[k] = v.value
cfg = _RequestConfig(
node_cls=cls,
endpoint=endpoint,
timeout=timeout,
content_type=content_type,
data=data,
files=files,
multipart_parser=multipart_parser,
max_retries=max_retries,
retry_delay=retry_delay,
retry_backoff=retry_backoff,
wait_label=wait_label,
monitor_progress=monitor_progress,
estimated_total=estimated_duration,
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
)
return await _request_base(cfg, expect_binary=as_binary)
async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
completed_statuses: Optional[list[Union[str, int]]] = None,
failed_statuses: Optional[list[Union[str, int]]] = None,
queued_statuses: Optional[list[Union[str, int]]] = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None,
cancel_endpoint: Optional[ApiEndpoint] = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
checks interruption every second, and calls Cancel endpoint (if provided) on interruption.
    Falls back to default sets of completed, failed, and queued statuses when none are provided.
Returns the final JSON response from the poll endpoint.
"""
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
started = time.monotonic()
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: Optional[int] = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
async def _ticker():
"""Emit a UI update every second while polling is in progress."""
try:
while not stop_ticker.is_set():
if is_processing_interrupted():
break
now = time.monotonic()
proc_elapsed = state.base_processing_elapsed + (
(now - state.active_since) if state.active_since is not None else 0.0
)
_display_time_progress(
cls,
status=state.status_label,
elapsed_seconds=int(now - state.started),
estimated_total=state.estimated_duration,
price=state.price,
is_queued=state.is_queued,
processing_elapsed_seconds=int(proc_elapsed),
)
await asyncio.sleep(1.0)
except Exception as exc:
logging.debug("Polling ticker exited: %s", exc)
ticker_task = asyncio.create_task(_ticker())
try:
while consumed_attempts < max_poll_attempts:
try:
resp_json = await sync_op_raw(
cls,
poll_endpoint,
data=data,
timeout=timeout_per_poll,
max_retries=max_retries_per_poll,
retry_delay=retry_delay_per_poll,
retry_backoff=retry_backoff_per_poll,
wait_label="Checking",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
if not isinstance(resp_json, dict):
raise Exception("Polling endpoint returned non-JSON response.")
except ProcessingInterrupted:
if cancel_endpoint:
with contextlib.suppress(Exception):
await sync_op_raw(
cls,
cancel_endpoint,
timeout=cancel_timeout,
max_retries=0,
wait_label="Cancelling task",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
raise
try:
status = _normalize_status_value(status_extractor(resp_json))
except Exception as e:
logging.error("Status extraction failed: %s", e)
status = None
if price_extractor:
new_price = price_extractor(resp_json)
if new_price is not None:
state.price = new_price
if progress_extractor:
new_progress = progress_extractor(resp_json)
if new_progress is not None and last_progress != new_progress:
progress_bar.update_absolute(new_progress, total=100)
last_progress = new_progress
now_ts = time.monotonic()
is_queued = status in queued_states
if is_queued:
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
else:
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
state.active_since = now_ts
state.is_queued = is_queued
state.status_label = status or ("Queued" if is_queued else "Processing")
if status in completed_states:
if state.active_since is not None:
state.base_processing_elapsed += now_ts - state.active_since
state.active_since = None
stop_ticker.set()
with contextlib.suppress(Exception):
await ticker_task
if progress_bar and last_progress != 100:
progress_bar.update_absolute(100, total=100)
_display_time_progress(
cls,
status=status if status else "Completed",
elapsed_seconds=int(now_ts - started),
estimated_total=estimated_duration,
price=state.price,
is_queued=False,
processing_elapsed_seconds=int(state.base_processing_elapsed),
)
return resp_json
if status in failed_states:
msg = f"Task failed: {json.dumps(resp_json)}"
logging.error(msg)
raise Exception(msg)
try:
await sleep_with_interrupt(poll_interval, cls, None, None, None)
except ProcessingInterrupted:
if cancel_endpoint:
with contextlib.suppress(Exception):
await sync_op_raw(
cls,
cancel_endpoint,
timeout=cancel_timeout,
max_retries=0,
wait_label="Cancelling task",
estimated_duration=None,
as_binary=False,
final_label_on_success=None,
monitor_progress=False,
)
raise
if not is_queued:
consumed_attempts += 1
raise Exception(
f"Polling timed out after {max_poll_attempts} non-queued attempts "
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
)
except ProcessingInterrupted:
raise
except (LocalNetworkError, ApiServerError):
raise
except Exception as e:
raise Exception(f"Polling aborted due to error: {e}") from e
finally:
stop_ticker.set()
with contextlib.suppress(Exception):
await ticker_task
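# Raw-variant sketch: the same loop with dict-based extractors and a best-effort
# cancel endpoint (path and fields are hypothetical):
#
#     resp = await poll_op_raw(
#         cls,
#         ApiEndpoint(path=f"/proxy/example/tasks/{task_id}"),
#         status_extractor=lambda d: d.get("status"),
#         cancel_endpoint=ApiEndpoint(path=f"/proxy/example/tasks/{task_id}/cancel", method="POST"),
#     )
#
# On interruption the cancel endpoint is invoked with errors suppressed before
# ProcessingInterrupted propagates to the caller.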
def _display_text(
node_cls: type[IO.ComfyNode],
text: Optional[str],
*,
status: Optional[Union[str, int]] = None,
price: Optional[float] = None,
) -> None:
display_lines: list[str] = []
if status:
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
if price is not None:
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
if p != "0":
display_lines.append(f"Price: ${p}")
if text is not None:
display_lines.append(text)
if display_lines:
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
def _display_time_progress(
node_cls: type[IO.ComfyNode],
status: Optional[Union[str, int]],
elapsed_seconds: int,
estimated_total: Optional[int] = None,
*,
price: Optional[float] = None,
is_queued: Optional[bool] = None,
processing_elapsed_seconds: Optional[int] = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
remaining = max(0, int(estimated_total) - int(pe))
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
else:
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
_display_text(node_cls, time_line, status=status, price=price)
async def _diagnose_connectivity() -> dict[str, bool]:
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
results = {
"internet_accessible": False,
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
async with aiohttp.ClientSession(timeout=timeout) as session:
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
if not results["internet_accessible"]:
return results
parsed = urlparse(default_base_url())
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
with contextlib.suppress(ClientError, OSError):
async with session.get(health_url) as resp:
results["api_accessible"] = resp.status < 500
return results
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
"""Normalize (filename, value, content_type)."""
if len(t) == 2:
return t[0], t[1], "application/octet-stream"
if len(t) == 3:
return t[0], t[1], t[2]
raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
if v is not None:
params[k] = v
return params
def _friendly_http_message(status: int, body: Any) -> str:
if status == 401:
return "Unauthorized: Please login first to use this node."
if status == 402:
return "Payment Required: Please add credits to your account to use this node."
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
return "Rate Limit Exceeded: Please try again later."
try:
if isinstance(body, dict):
err = body.get("error")
if isinstance(err, dict):
msg = err.get("message")
typ = err.get("type")
if msg and typ:
return f"API Error: {msg} (Type: {typ})"
if msg:
return f"API Error: {msg}"
return f"API Error: {json.dumps(body)}"
else:
txt = str(body)
if len(txt) <= 200:
return f"API Error (raw): {txt}"
return f"API Error (status {status})"
except Exception:
return f"HTTP {status}: Unknown error"
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
slug = path.strip("/").replace("/", "_") or "op"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
data: Optional[dict[str, Any]],
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
) -> Optional[Union[dict[str, Any], str]]:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
file_fields: list[dict[str, str]] = []
if files:
file_iter = files if isinstance(files, list) else list(files.items())
for field_name, file_obj in file_iter:
if file_obj is None:
continue
if isinstance(file_obj, tuple):
filename = file_obj[0]
else:
filename = getattr(file_obj, "name", field_name)
file_fields.append({"field": field_name, "filename": str(filename or "")})
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
if content_type == "application/x-www-form-urlencoded":
return data or {}
return data or {}
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
url = cfg.endpoint.path
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
method = cfg.endpoint.method
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
"""Every second: update elapsed time and signal interruption."""
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return # normal shutdown
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None
extracted_price: Optional[float] = None
while True:
attempt += 1
stop_event = asyncio.Event()
monitor_task: Optional[asyncio.Task] = None
sess: Optional[aiohttp.ClientSession] = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers:
payload_headers.update(cfg.endpoint.headers)
payload_kw: dict[str, Any] = {"headers": payload_headers}
if method == "GET":
payload_headers.pop("Content-Type", None)
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
try:
if cfg.monitor_progress:
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
sess = aiohttp.ClientSession(timeout=timeout)
if cfg.content_type == "multipart/form-data" and method != "GET":
                # aiohttp sets the multipart Content-Type (including the boundary); drop any fixed Content-Type
payload_headers.pop("Content-Type", None)
if cfg.multipart_parser and cfg.data:
form = cfg.multipart_parser(cfg.data)
if not isinstance(form, aiohttp.FormData):
raise ValueError("multipart_parser must return aiohttp.FormData")
else:
form = aiohttp.FormData(default_to_multipart=True)
if cfg.data:
for k, v in cfg.data.items():
if v is None:
continue
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
if cfg.files:
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
for field_name, file_obj in file_iter:
if file_obj is None:
continue
if isinstance(file_obj, tuple):
filename, file_value, content_type = _unpack_tuple(file_obj)
else:
filename = getattr(file_obj, "name", field_name)
file_value = file_obj
content_type = "application/octet-stream"
# Attempt to rewind BytesIO for retries
if isinstance(file_value, BytesIO):
with contextlib.suppress(Exception):
file_value.seek(0)
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
payload_kw["data"] = form
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
payload_kw["data"] = cfg.data or {}
elif method != "GET":
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] request logging failed: %s", _log_e)
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
# Race: request vs. monitor (interruption)
tasks = {req_task}
if monitor_task:
tasks.add(monitor_task)
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
if monitor_task and monitor_task in done:
            # Interrupted: cancel the in-flight request and abort
if req_task in pending:
req_task.cancel()
raise ProcessingInterrupted("Task cancelled")
# Otherwise, request finished
resp = await req_task
async with resp:
if resp.status >= 400:
try:
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
logging.warning(
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
method,
url,
resp.status,
delay,
attempt,
cfg.max_retries,
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=_friendly_http_message(resp.status, body),
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
msg = _friendly_http_message(resp.status, body)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=msg,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
raise Exception(msg)
if expect_binary:
buff = bytearray()
last_tick = time.monotonic()
async for chunk in resp.content.iter_chunked(64 * 1024):
buff.extend(chunk)
now = time.monotonic()
if now - last_tick >= 1.0:
last_tick = now
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
if cfg.monitor_progress:
_display_time_progress(
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=bytes_payload,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return bytes_payload
else:
try:
payload = await resp.json()
response_content_to_log: Any = payload
except (ContentTypeError, json.JSONDecodeError):
text = await resp.text()
try:
payload = json.loads(text) if text else {}
except json.JSONDecodeError:
payload = {"_raw": text}
response_content_to_log = payload if isinstance(payload, dict) else text
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=response_content_to_log,
)
except Exception as _log_e:
logging.debug("[DEBUG] response logging failed: %s", _log_e)
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
if attempt <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
attempt,
cfg.max_retries,
str(e),
)
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
except Exception as _log_e:
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
await sleep_with_interrupt(
delay,
cfg.node_cls,
cfg.wait_label if cfg.monitor_progress else None,
start_time if cfg.monitor_progress else None,
cfg.estimated_total,
display_callback=_display_time_progress if cfg.monitor_progress else None,
)
delay *= cfg.retry_backoff
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"LocalNetworkError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
error_message=f"ApiServerError: {str(e)}",
)
except Exception as _log_e:
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."
) from e
finally:
stop_event.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if sess:
with contextlib.suppress(Exception):
await sess.close()
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
_display_time_progress(
cfg.node_cls,
status=cfg.final_label_on_success,
elapsed_seconds=(
final_elapsed_seconds
if final_elapsed_seconds is not None
else int(time.monotonic() - start_time)
),
estimated_total=cfg.estimated_total,
price=extracted_price,
is_queued=False,
processing_elapsed_seconds=final_elapsed_seconds,
)
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
logging.error(
"Response validation failed for %s: %s",
getattr(response_model, "__name__", response_model),
e,
)
raise Exception(
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
) from e
def _wrap_model_extractor(
response_model: Type[M],
extractor: Optional[Callable[[M], Any]],
) -> Optional[Callable[[dict[str, Any]], Any]]:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
the same response for multiple extractors in a single poll attempt.
"""
if extractor is None:
return None
_cache: dict[int, M] = {}
def _wrapped(d: dict[str, Any]) -> Any:
try:
key = id(d)
model = _cache.get(key)
if model is None:
model = response_model.model_validate(d)
_cache[key] = model
return extractor(model)
except Exception as e:
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
raise
return _wrapped
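# Sketch: for a typed `status_extractor=lambda m: m.state` written against a
# hypothetical MyStatusResponse, _wrap_model_extractor(MyStatusResponse, extractor)
# yields a callable that takes the raw poll dict, validates it once (cached by
# id(dict)), and forwards the typed model, so the status/progress/price extractors
# of a single poll attempt share one validation.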
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
if not values:
return set()
out: set[Union[str, int]] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
out.add(nv)
return out
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
if isinstance(val, str):
return val.strip().lower()
return val

View File

@@ -0,0 +1,14 @@
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
class ProcessingInterrupted(Exception):
"""Operation was interrupted by user/runtime via processing_interrupted()."""

View File

@@ -0,0 +1,470 @@
import base64
import logging
import math
import mimetypes
import uuid
from io import BytesIO
from typing import Optional
import av
import numpy as np
import torch
from PIL import Image
from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
"""Converts image data from BytesIO to a torch.Tensor.
Args:
image_bytesio: BytesIO object containing the image data.
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
PIL.UnidentifiedImageError: If the image data cannot be identified.
ValueError: If the specified mode is invalid.
"""
image = Image.open(image_bytesio)
image = image.convert(mode)
image_array = np.array(image).astype(np.float32) / 255.0
return torch.from_numpy(image_array).unsqueeze(0)
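# Usage sketch: turn downloaded PNG bytes into the [1, H, W, C] float tensor layout
# used by ComfyUI (values scaled to 0..1):
#
#     tensor = bytesio_to_image_tensor(BytesIO(png_bytes), mode="RGB")  # [1, H, W, 3]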
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
def tensor_to_bytesio(
image: torch.Tensor,
name: Optional[str] = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
Args:
image: Input torch.Tensor image.
name: Optional filename for the BytesIO object.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Named BytesIO object containing the image data, with pointer set to the start of buffer.
"""
if not mime_type:
mime_type = "image/png"
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
return img_binary
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
if len(image.shape) > 3:
image = image[0]
# TODO: remove alpha if not allowed and present
input_tensor = image.cpu()
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
img = Image.fromarray(image_np)
return img
def tensor_to_base64_string(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
Returns:
Base64 encoded string of the image.
"""
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
img_bytes = img_byte_arr.getvalue()
# Encode bytes to base64 string
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
return base64_encoded_string
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
"""Converts a PIL Image to a BytesIO object."""
if not mime_type:
mime_type = "image/png"
img_byte_arr = BytesIO()
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
pil_format = mime_type.split("/")[-1].upper()
if pil_format == "JPG":
pil_format = "JPEG"
img.save(img_byte_arr, format=pil_format)
img_byte_arr.seek(0)
return img_byte_arr
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
"""Downscale input image tensor to roughly the specified total pixels."""
samples = image.movedim(-1, 1)
total = int(total_pixels)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
if scale_by >= 1:
return image
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = common_upscale(samples, width, height, "lanczos", "disabled")
s = s.movedim(1, -1)
return s
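# Worked example: a 4096x4096 input with total_pixels=1536*1024 gives
# scale_by = sqrt(1572864 / 16777216) ≈ 0.306, i.e. roughly a 1254x1254 result;
# inputs already at or below the pixel budget (scale_by >= 1) are returned unchanged.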
def tensor_to_data_uri(
image_tensor: torch.Tensor,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> str:
"""Converts a tensor image to a Data URI string.
Args:
image_tensor: Input torch.Tensor image.
total_pixels: Maximum total pixels for potential downscaling.
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
Returns:
Data URI string (e.g., 'data:image/png;base64,...').
"""
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
return f"data:{mime_type};base64,{base64_string}"
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
def video_to_base64_string(
video: Input.Video,
container_format: Optional[VideoContainer] = None,
codec: Optional[VideoCodec] = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
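# Example (illustrative sketch, assuming `video` is an Input.Video instance):
#
#     b64 = video_to_base64_string(video, container_format=VideoContainer.MP4,
#                                  codec=VideoCodec.H264)
#     # Explicit arguments override whatever the video object reports; omitting
#     # them falls back to video.container / video.codec when present.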
def audio_ndarray_to_bytesio(
audio_data_np: np.ndarray,
sample_rate: int,
container_format: str = "mp4",
codec_name: str = "aac",
) -> BytesIO:
"""
Encodes a numpy array of audio data into a BytesIO object.
"""
audio_bytes_io = BytesIO()
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
frame = av.AudioFrame.from_ndarray(
audio_data_np,
format="fltp",
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
)
frame.sample_rate = sample_rate
frame.pts = 0
for packet in audio_stream.encode(frame):
output_container.mux(packet)
# Flush stream
for packet in audio_stream.encode(None):
output_container.mux(packet)
audio_bytes_io.seek(0)
return audio_bytes_io
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
"""
Prepares audio waveform for av library by converting to a contiguous numpy array.
Args:
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
Returns:
Contiguous numpy array of the audio waveform with shape (channels, samples).
"""
if waveform.ndim != 3 or waveform.shape[0] != 1:
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
# Prepare for av: remove the batch dim, move to CPU, make contiguous, convert to a numpy array
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
if audio_data_np.dtype != np.float32:
audio_data_np = audio_data_np.astype(np.float32)
return audio_data_np
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
waveform = audio["waveform"].cpu()
output_buffer = BytesIO()
output_container = av.open(output_buffer, mode="w", format="mp3")
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
out_stream.bit_rate = 320000
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo",
)
frame.sample_rate = audio["sample_rate"]
frame.pts = 0
output_container.mux(out_stream.encode(frame))
output_container.mux(out_stream.encode(None))
output_container.close()
output_buffer.seek(0)
return output_buffer
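# Example (illustrative, assuming `audio` is a Comfy AUDIO dict):
#
#     mp3_buf = audio_input_to_mp3(audio)
#     with open("out.mp3", "wb") as f:
#         f.write(mp3_buf.getvalue())
#     # The buffer holds a complete 320 kbps MP3 stream, rewound to position 0.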
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
"""
Returns a new VideoInput object trimmed from the beginning to the specified duration,
using av to avoid loading entire video into memory.
Args:
video: Input video to trim
duration_sec: Duration in seconds to keep from the beginning
Returns:
VideoFromFile object that owns the output buffer
"""
output_buffer = BytesIO()
input_container = None
output_container = None
try:
# Get the stream source - this avoids loading entire video into memory
# when the source is already a file path
input_source = video.get_stream_source()
# Open containers
input_container = av.open(input_source, mode="r")
output_container = av.open(output_buffer, mode="w", format="mp4")
# Set up output streams for re-encoding
video_stream = None
audio_stream = None
for stream in input_container.streams:
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
if isinstance(stream, av.VideoStream):
# Create output video stream with same parameters
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
video_stream.width = stream.width
video_stream.height = stream.height
video_stream.pix_fmt = "yuv420p"
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
elif isinstance(stream, av.AudioStream):
# Create output audio stream with same parameters
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
audio_stream.sample_rate = stream.sample_rate
audio_stream.layout = stream.layout
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
# Calculate target frame count that's divisible by 16
fps = input_container.streams.video[0].average_rate
estimated_frames = int(duration_sec * fps)
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
if target_frames == 0:
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
frame_count = 0
audio_frame_count = 0
# Decode and re-encode video frames
if video_stream:
for frame in input_container.decode(video=0):
if frame_count >= target_frames:
break
# Re-encode frame
for packet in video_stream.encode(frame):
output_container.mux(packet)
frame_count += 1
# Flush encoder
for packet in video_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
# Decode and re-encode audio frames
if audio_stream:
input_container.seek(0) # Reset to beginning for audio
for frame in input_container.decode(audio=0):
if frame.time >= duration_sec:
break
# Re-encode frame
for packet in audio_stream.encode(frame):
output_container.mux(packet)
audio_frame_count += 1
# Flush encoder
for packet in audio_stream.encode():
output_container.mux(packet)
logging.info("Encoded %s audio frames", audio_frame_count)
# Close containers
output_container.close()
input_container.close()
# Return as VideoFromFile using the buffer
output_buffer.seek(0)
return InputImpl.VideoFromFile(output_buffer)
except Exception as e:
# Clean up on error
if input_container is not None:
input_container.close()
if output_container is not None:
output_container.close()
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
if wav.dtype.is_floating_point:
return wav
elif wav.dtype == torch.int16:
return wav.float() / (2**15)
elif wav.dtype == torch.int32:
return wav.float() / (2**31)
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
"""
Decode any common audio container from bytes using PyAV and return
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
"""
with av.open(BytesIO(audio_bytes)) as af:
if not af.streams.audio:
raise ValueError("No audio stream found in response.")
stream = af.streams.audio[0]
in_sr = int(stream.codec_context.sample_rate)
out_sr = in_sr
frames: list[torch.Tensor] = []
n_channels = stream.channels or 1
for frame in af.decode(streams=stream.index):
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
buf = torch.from_numpy(arr)
if buf.ndim == 1:
buf = buf.unsqueeze(0) # [T] -> [1, T]
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
elif buf.shape[0] != n_channels:
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
frames.append(buf)
if not frames:
raise ValueError("Decoded zero audio frames.")
wav = torch.cat(frames, dim=1) # [C, T]
wav = _f32_pcm(wav)
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
_, height, width, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
mask = (mask > 0.5).float()
return mask
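# Example (illustrative, assuming `mask` is [B, H', W'] and `image` is
# [B, H, W, C]):
#
#     hard_mask = resize_mask_to_image(mask, image, allow_gradient=False)
#     # -> [B, H, W] binary mask; pass add_channel_dim=True to get
#     # [B, H, W, 1] when the target API expects an explicit channel axis.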
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"

View File

@ -0,0 +1,262 @@
import asyncio
import contextlib
import uuid
from io import BytesIO
from pathlib import Path
from typing import IO, Optional, Union
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
from . import request_logger
from ._helpers import (
default_base_url,
get_auth_header,
is_processing_interrupted,
sleep_with_interrupt,
)
from .client import _diagnose_connectivity
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
from .conversions import bytesio_to_image_tensor
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
*,
timeout: Optional[float] = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> None:
"""Stream-download a URL to `dest`.
`dest` must be one of:
- a BytesIO (rewound to 0 after write),
- a file-like object opened in binary write mode (must implement .write()),
- a filesystem path (str | pathlib.Path), which will be opened with 'wb'.
If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded
to an absolute URL and authentication headers can be applied.
Raises:
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
"""
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
attempt = 0
delay = retry_delay
headers: dict[str, str] = {}
parsed_url = urlparse(url)
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
if cls is None:
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
headers = get_auth_header(cls)
while True:
attempt += 1
op_id = _generate_operation_id("GET", url, attempt)
timeout_cfg = aiohttp.ClientTimeout(total=timeout)
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
session: Optional[aiohttp.ClientSession] = None
stop_evt: Optional[asyncio.Event] = None
monitor_task: Optional[asyncio.Task] = None
req_task: Optional[asyncio.Task] = None
try:
with contextlib.suppress(Exception):
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
session = aiohttp.ClientSession(timeout=timeout_cfg)
stop_evt = asyncio.Event()
async def _monitor():
try:
while not stop_evt.is_set():
if is_processing_interrupted():
return
await asyncio.sleep(1.0)
except asyncio.CancelledError:
return
monitor_task = asyncio.create_task(_monitor())
req_task = asyncio.create_task(session.get(url, headers=headers))
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
if monitor_task in done and req_task in pending:
req_task.cancel()
with contextlib.suppress(Exception):
await req_task
raise ProcessingInterrupted("Task cancelled")
try:
resp = await req_task
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
async with resp:
if resp.status >= 400:
with contextlib.suppress(Exception):
try:
body = await resp.json()
except (ContentTypeError, ValueError):
text = await resp.text()
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
error_message=f"HTTP {resp.status}",
)
if resp.status in _RETRY_STATUS and attempt <= max_retries:
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
raise Exception(f"Failed to download (HTTP {resp.status}).")
if is_path_sink:
p = Path(str(dest))
with contextlib.suppress(Exception):
p.parent.mkdir(parents=True, exist_ok=True)
fhandle = open(p, "wb")
sink = fhandle
else:
sink = dest # BytesIO or file-like
written = 0
while True:
try:
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
except asyncio.TimeoutError:
chunk = b""
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
if is_processing_interrupted():
raise ProcessingInterrupted("Task cancelled")
if not chunk:
if resp.content.at_eof():
break
continue
sink.write(chunk)
written += len(chunk)
if isinstance(dest, BytesIO):
with contextlib.suppress(Exception):
dest.seek(0)
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=f"[streamed {written} bytes to dest]",
)
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
with contextlib.suppress(Exception):
request_logger.log_request_response(
operation_id=op_id,
request_method="GET",
request_url=url,
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
)
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
raise LocalNetworkError(
"Unable to connect to the network. Please check your internet connection and try again."
) from e
raise ApiServerError("The remote service appears unreachable at this time.") from e
finally:
if stop_evt is not None:
stop_evt.set()
if monitor_task:
monitor_task.cancel()
with contextlib.suppress(Exception):
await monitor_task
if req_task and not req_task.done():
req_task.cancel()
with contextlib.suppress(Exception):
await req_task
if session:
with contextlib.suppress(Exception):
await session.close()
if fhandle:
with contextlib.suppress(Exception):
fhandle.flush()
fhandle.close()
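# Example (illustrative sketch; `cls` is only needed for relative /proxy/ URLs):
#
#     buf = BytesIO()
#     await download_url_to_bytesio("https://example.com/image.png", buf,
#                                   timeout=60.0)
#     # Streams in 1 MiB chunks with retry/backoff on 408/429/5xx, honours
#     # user interruption, and rewinds BytesIO destinations before returning.
#     # Passing a str/Path instead of a buffer writes straight to disk.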
async def download_url_to_image_tensor(
url: str,
*,
timeout: Optional[float] = None,
cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
result = BytesIO()
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
return bytesio_to_image_tensor(result)
async def download_url_to_video_output(
video_url: str,
*,
timeout: Optional[float] = None,
max_retries: int = 5,
cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return VideoFromFile(result)
async def download_url_as_bytesio(
url: str,
*,
timeout: Optional[float] = None,
cls: Optional[type[COMFY_IO.ComfyNode]] = None,
) -> BytesIO:
"""Downloads content from a URL and returns a new BytesIO (rewound to 0)."""
result = BytesIO()
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
return result
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
try:
parsed = urlparse(url)
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
except Exception:
slug = "download"
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"

View File

@ -1,11 +1,11 @@
from __future__ import annotations

import datetime
import hashlib
import json
import logging
import os
import re
from typing import Any

import folder_paths

Some files were not shown because too many files have changed in this diff.