mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-09 05:54:24 +08:00
Merge branch 'master' into preinstall-enhancements
This commit is contained in:
commit
592ce3db7d
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
21
.github/PULL_REQUEST_TEMPLATE/api-node.md
vendored
Normal file
@ -0,0 +1,21 @@
|
||||
<!-- API_NODE_PR_CHECKLIST: do not remove -->
|
||||
|
||||
## API Node PR Checklist
|
||||
|
||||
### Scope
|
||||
- [ ] **Is API Node Change**
|
||||
|
||||
### Pricing & Billing
|
||||
- [ ] **Need pricing update**
|
||||
- [ ] **No pricing update**
|
||||
|
||||
If **Need pricing update**:
|
||||
- [ ] Metronome rate cards updated
|
||||
- [ ] Auto‑billing tests updated and passing
|
||||
|
||||
### QA
|
||||
- [ ] **QA done**
|
||||
- [ ] **QA not required**
|
||||
|
||||
### Comms
|
||||
- [ ] Informed **Kosinkadink**
|
||||
58
.github/workflows/api-node-template.yml
vendored
Normal file
58
.github/workflows/api-node-template.yml
vendored
Normal file
@ -0,0 +1,58 @@
|
||||
name: Append API Node PR template
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [opened, reopened, synchronize, ready_for_review]
|
||||
paths:
|
||||
- 'comfy_api_nodes/**' # only run if these files changed
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
inject:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Ensure template exists and append to PR body
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const { owner, repo } = context.repo;
|
||||
const number = context.payload.pull_request.number;
|
||||
const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md';
|
||||
const marker = '<!-- API_NODE_PR_CHECKLIST: do not remove -->';
|
||||
|
||||
const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number });
|
||||
|
||||
let templateText;
|
||||
try {
|
||||
const res = await github.rest.repos.getContent({
|
||||
owner,
|
||||
repo,
|
||||
path: templatePath,
|
||||
ref: pr.base.ref
|
||||
});
|
||||
const buf = Buffer.from(res.data.content, res.data.encoding || 'base64');
|
||||
templateText = buf.toString('utf8');
|
||||
} catch (e) {
|
||||
core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Enforce the presence of the marker inside the template (for idempotence)
|
||||
if (!templateText.includes(marker)) {
|
||||
core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`);
|
||||
return;
|
||||
}
|
||||
|
||||
// If the PR already contains the marker, do not append again.
|
||||
const body = pr.body || '';
|
||||
if (body.includes(marker)) {
|
||||
core.info('Template already present in PR body; nothing to inject.');
|
||||
return;
|
||||
}
|
||||
|
||||
const newBody = (body ? body + '\n\n' : '') + templateText + '\n';
|
||||
await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody });
|
||||
core.notice('API Node template appended to PR description.');
|
||||
19
.github/workflows/release-stable-all.yml
vendored
19
.github/workflows/release-stable-all.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA Default (cu129)"
|
||||
name: "Release NVIDIA Default (cu130)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
@ -43,6 +43,23 @@ jobs:
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_nvidia_cu126:
|
||||
permissions:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA cu126"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "cu126"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "nvidia"
|
||||
rel_extra_name: "_cu126"
|
||||
test_release: true
|
||||
secrets: inherit
|
||||
|
||||
release_amd_rocm:
|
||||
permissions:
|
||||
contents: "write"
|
||||
|
||||
20
.github/workflows/test-ci.yml
vendored
20
.github/workflows/test-ci.yml
vendored
@ -21,14 +21,15 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
# os: [macos, linux, windows]
|
||||
os: [macos, linux]
|
||||
python_version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.10", "3.11", "3.12"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["stable"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
@ -73,14 +74,15 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [macos, linux]
|
||||
# os: [macos, linux]
|
||||
os: [linux]
|
||||
python_version: ["3.11"]
|
||||
cuda_version: ["12.1"]
|
||||
torch_version: ["nightly"]
|
||||
include:
|
||||
- os: macos
|
||||
runner_label: [self-hosted, macOS]
|
||||
flags: "--use-pytorch-cross-attention"
|
||||
# - os: macos
|
||||
# runner_label: [self-hosted, macOS]
|
||||
# flags: "--use-pytorch-cross-attention"
|
||||
- os: linux
|
||||
runner_label: [self-hosted, Linux]
|
||||
flags: ""
|
||||
|
||||
168
QUANTIZATION.md
Normal file
168
QUANTIZATION.md
Normal file
@ -0,0 +1,168 @@
|
||||
# The Comfy guide to Quantization
|
||||
|
||||
|
||||
## How does quantization work?
|
||||
|
||||
Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware.
|
||||
|
||||
When simply converting a value from FP16 to FP8 using the round-nearest method we might hit two issues:
|
||||
- The dynamic range of FP16 (-65,504, 65,504) far exceeds FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values
|
||||
- The original values are concentrated in a small range (e.g. -1,1) leaving many FP8-bits "unused"
|
||||
|
||||
By using a scaling factor, we aim to map these values into the quantized-dtype range, making use of the full spectrum. One of the easiest approaches, and common, is using per-tensor absolute-maximum scaling.
|
||||
|
||||
```
|
||||
absmax = max(abs(tensor))
|
||||
scale = amax / max_dynamic_range_low_precision
|
||||
|
||||
# Quantization
|
||||
tensor_q = (tensor / scale).to(low_precision_dtype)
|
||||
|
||||
# De-Quantization
|
||||
tensor_dq = tensor_q.to(fp16) * scale
|
||||
|
||||
tensor_dq ~ tensor
|
||||
```
|
||||
|
||||
Given that additional information (scaling factor) is needed to "interpret" the quantized values, we describe those as derived datatypes.
|
||||
|
||||
|
||||
## Quantization in Comfy
|
||||
|
||||
```
|
||||
QuantizedTensor (torch.Tensor subclass)
|
||||
↓ __torch_dispatch__
|
||||
Two-Level Registry (generic + layout handlers)
|
||||
↓
|
||||
MixedPrecisionOps + Metadata Detection
|
||||
```
|
||||
|
||||
### Representation
|
||||
|
||||
To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor to implements these using the `QuantizedTensor` class found in `comfy/quant_ops.py`
|
||||
|
||||
A `Layout` class defines how a specific quantization format behaves:
|
||||
- Required parameters
|
||||
- Quantize method
|
||||
- De-Quantize method
|
||||
|
||||
```python
|
||||
from comfy.quant_ops import QuantizedLayout
|
||||
|
||||
class MyLayout(QuantizedLayout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, **kwargs):
|
||||
# Convert to quantized format
|
||||
qdata = ...
|
||||
params = {'scale': ..., 'orig_dtype': tensor.dtype}
|
||||
return qdata, params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
return qdata.to(orig_dtype) * scale
|
||||
```
|
||||
|
||||
To then run operations using these QuantizedTensors we use two registry systems to define supported operations.
|
||||
The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`).
|
||||
|
||||
The second registry is layout-specific and allows to implement fast-paths like nn.Linear.
|
||||
```python
|
||||
from comfy.quant_ops import register_layout_op
|
||||
|
||||
@register_layout_op(torch.ops.aten.linear.default, MyLayout)
|
||||
def my_linear(func, args, kwargs):
|
||||
# Extract tensors, call optimized kernel
|
||||
...
|
||||
```
|
||||
When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation.
|
||||
For any unsupported operation, QuantizedTensor will fallback to call `dequantize` and dispatch using the high-precision implementation.
|
||||
|
||||
|
||||
### Mixed Precision
|
||||
|
||||
The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
|
||||
|
||||
**Architecture:**
|
||||
|
||||
```python
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {} # Maps layer names to quantization configs
|
||||
_compute_dtype = torch.bfloat16 # Default compute / dequantize precision
|
||||
```
|
||||
|
||||
**Key mechanism:**
|
||||
|
||||
The custom `Linear._load_from_state_dict()` method inspects each layer during model loading:
|
||||
- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype`
|
||||
- If the layer name **is** in `_layer_quant_config`:
|
||||
- Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`)
|
||||
- Load associated quantization parameters (scales, block_size, etc.)
|
||||
|
||||
**Why it's needed:**
|
||||
|
||||
Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality.
|
||||
|
||||
The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode.
|
||||
|
||||
|
||||
## Checkpoint Format
|
||||
|
||||
Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme.
|
||||
|
||||
The quantized checkpoint will contain the same layers as the original checkpoint but:
|
||||
- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8.
|
||||
- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe.
|
||||
- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used.
|
||||
|
||||
### Scaling Parameters details
|
||||
We define 4 possible scaling parameters that should cover most recipes in the near-future:
|
||||
- **weight_scale**: quantization scalers for the weights
|
||||
- **weight_scale_2**: global scalers in the context of double scaling
|
||||
- **pre_quant_scale**: scalers used for smoothing salient weights
|
||||
- **input_scale**: quantization scalers for the activations
|
||||
|
||||
| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale |
|
||||
|--------|---------------|--------------|----------------|-----------------|-------------|
|
||||
| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) |
|
||||
|
||||
You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS).
|
||||
|
||||
### Quantization Metadata
|
||||
|
||||
The metadata stored alongside the checkpoint contains:
|
||||
- **format_version**: String to define a version of the standard
|
||||
- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"_quantization_metadata": {
|
||||
"format_version": "1.0",
|
||||
"layers": {
|
||||
"model.layers.0.mlp.up_proj": "float8_e4m3fn",
|
||||
"model.layers.0.mlp.down_proj": "float8_e4m3fn",
|
||||
"model.layers.1.mlp.up_proj": "float8_e4m3fn"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Creating Quantized Checkpoints
|
||||
|
||||
To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`.
|
||||
|
||||
### Weight Quantization
|
||||
|
||||
Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter.
|
||||
|
||||
### Calibration (for Activation Quantization)
|
||||
|
||||
Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**:
|
||||
|
||||
1. **Collect statistics**: Run inference on N representative samples
|
||||
2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer
|
||||
3. **Compute scales**: Derive `input_scale` from collected statistics
|
||||
4. **Store in checkpoint**: Save `input_scale` parameters alongside weights
|
||||
|
||||
The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters.
|
||||
16
README.md
16
README.md
@ -173,7 +173,7 @@ There is a portable standalone build for Windows that should work for running on
|
||||
|
||||
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
|
||||
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
|
||||
Simply download, extract with [7-Zip](https://7-zip.org) or with the windows explorer on recent windows versions and run. For smaller models you normally only need to put the checkpoints (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints but many of the larger models have multiple files. Make sure to follow the instructions to know which subfolder to put them in ComfyUI\models\
|
||||
|
||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||
|
||||
@ -183,7 +183,9 @@ Update your Nvidia drivers if it doesn't start.
|
||||
|
||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
[Portable with pytorch cuda 12.8 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu128.7z).
|
||||
|
||||
[Portable with pytorch cuda 12.6 and python 3.12](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia_cu126.7z) (Supports Nvidia 10 series and older GPUs).
|
||||
|
||||
#### How do I share models between another UI and ComfyUI?
|
||||
|
||||
@ -200,7 +202,7 @@ comfy install
|
||||
|
||||
## Manual Install (Windows, Linux)
|
||||
|
||||
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
|
||||
Python 3.14 works but you may encounter issues with the torch compile node. The free threaded variant is still missing some dependencies.
|
||||
|
||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||
|
||||
@ -221,7 +223,7 @@ AMD users can install rocm and pytorch with pip if you don't have it already ins
|
||||
|
||||
This is the command to install the nightly with ROCm 7.0 which might have some performance improvements:
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.0```
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm7.1```
|
||||
|
||||
|
||||
### AMD GPUs (Experimental: Windows and Linux), RDNA 3, 3.5 and 4 only.
|
||||
@ -242,7 +244,7 @@ RDNA 4 (RX 9000 series):
|
||||
|
||||
### Intel GPUs (Windows and Linux)
|
||||
|
||||
(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
Intel Arc GPU users can install native PyTorch with torch.xpu support using pip. More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
|
||||
|
||||
1. To install PyTorch xpu, use the following command:
|
||||
|
||||
@ -252,10 +254,6 @@ This is the command to install the Pytorch xpu nightly which might have some per
|
||||
|
||||
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
|
||||
|
||||
(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
|
||||
|
||||
1. visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
|
||||
|
||||
### NVIDIA
|
||||
|
||||
Nvidia users should install stable pytorch using this command:
|
||||
|
||||
@ -10,7 +10,8 @@ import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from typing import Dict, TypedDict, Optional
|
||||
from aiohttp import web
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
@ -257,7 +258,54 @@ comfyui-frontend-package is not installed.
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
||||
"""Return a mapping of template asset names to their absolute paths."""
|
||||
try:
|
||||
from comfyui_workflow_templates import (
|
||||
get_asset_path,
|
||||
iter_templates,
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
template_entries = list(iter_templates())
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
||||
return None
|
||||
|
||||
asset_map: Dict[str, str] = {}
|
||||
try:
|
||||
for entry in template_entries:
|
||||
for asset in entry.assets:
|
||||
asset_map[asset.filename] = get_asset_path(
|
||||
entry.template_id, asset.filename
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
||||
return None
|
||||
|
||||
if not asset_map:
|
||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
||||
return None
|
||||
|
||||
return asset_map
|
||||
|
||||
|
||||
@classmethod
|
||||
def legacy_templates_path(cls) -> Optional[str]:
|
||||
"""Return the legacy templates directory shipped inside the meta package."""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed.
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed.
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
assets = cls.template_asset_map()
|
||||
if not assets:
|
||||
return None
|
||||
|
||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
||||
rel_path = request.match_info.get("path", "")
|
||||
target = assets.get(rel_path)
|
||||
if target is None:
|
||||
raise web.HTTPNotFound()
|
||||
return web.FileResponse(target)
|
||||
|
||||
return serve_template
|
||||
|
||||
@ -160,7 +160,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
|
||||
@ -611,6 +611,66 @@ class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
[ 0.0186, 0.0531, -0.0138],
|
||||
[-0.0031, 0.0051, 0.0288],
|
||||
[ 0.0110, 0.0556, 0.0432],
|
||||
[-0.0041, -0.0023, -0.0485],
|
||||
[ 0.0530, 0.0413, 0.0253],
|
||||
[ 0.0283, 0.0251, 0.0339],
|
||||
[ 0.0277, -0.0372, -0.0093],
|
||||
[ 0.0393, 0.0944, 0.1131],
|
||||
[ 0.0020, 0.0251, 0.0037],
|
||||
[-0.0017, 0.0012, 0.0234],
|
||||
[ 0.0468, 0.0436, 0.0203],
|
||||
[ 0.0354, 0.0439, -0.0233],
|
||||
[ 0.0090, 0.0123, 0.0346],
|
||||
[ 0.0382, 0.0029, 0.0217],
|
||||
[ 0.0261, -0.0300, 0.0030],
|
||||
[-0.0088, -0.0220, -0.0283],
|
||||
[-0.0272, -0.0121, -0.0363],
|
||||
[-0.0664, -0.0622, 0.0144],
|
||||
[ 0.0414, 0.0479, 0.0529],
|
||||
[ 0.0355, 0.0612, -0.0247],
|
||||
[ 0.0147, 0.0264, 0.0174],
|
||||
[ 0.0438, 0.0038, 0.0542],
|
||||
[ 0.0431, -0.0573, -0.0033],
|
||||
[-0.0162, -0.0211, -0.0406],
|
||||
[-0.0487, -0.0295, -0.0393],
|
||||
[ 0.0005, -0.0109, 0.0253],
|
||||
[ 0.0296, 0.0591, 0.0353],
|
||||
[ 0.0119, 0.0181, -0.0306],
|
||||
[-0.0085, -0.0362, 0.0229],
|
||||
[ 0.0005, -0.0106, 0.0242]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -1,15 +1,15 @@
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.math import attention
|
||||
from comfy.ldm.flux.layers import (
|
||||
MLPEmbedder,
|
||||
RMSNorm,
|
||||
QKNorm,
|
||||
SelfAttention,
|
||||
ModulationOut,
|
||||
)
|
||||
|
||||
# TODO: remove this in a few months
|
||||
SingleStreamBlock = None
|
||||
DoubleStreamBlock = None
|
||||
|
||||
|
||||
class ChromaModulationOut(ModulationOut):
|
||||
@ -48,124 +48,6 @@ class Approximator(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}):
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
||||
|
||||
# calculate the img bloks
|
||||
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
|
||||
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
|
||||
|
||||
# calculate the txt bloks
|
||||
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
|
||||
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
class SingleStreamBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with parallel linear layers as described in
|
||||
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor:
|
||||
mod = vec
|
||||
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x.addcmul_(mod.gate, output)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -11,12 +11,12 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.flux.layers import (
|
||||
EmbedND,
|
||||
timestep_embedding,
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
)
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
LastLayer,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
ChromaModulationOut,
|
||||
)
|
||||
@ -90,6 +90,7 @@ class Chroma(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -98,7 +99,7 @@ class Chroma(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
@ -10,12 +10,10 @@ from torch import Tensor, nn
|
||||
from einops import repeat
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock
|
||||
|
||||
from comfy.ldm.chroma.model import Chroma, ChromaParams
|
||||
from comfy.ldm.chroma.layers import (
|
||||
DoubleStreamBlock,
|
||||
SingleStreamBlock,
|
||||
Approximator,
|
||||
)
|
||||
from .layers import (
|
||||
@ -89,7 +87,6 @@ class ChromaRadiance(Chroma):
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
@ -97,6 +94,7 @@ class ChromaRadiance(Chroma):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -109,6 +107,7 @@ class ChromaRadiance(Chroma):
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
modulation=False,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
|
||||
@ -130,13 +130,17 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
self.modulation = modulation
|
||||
|
||||
if self.modulation:
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@ -147,7 +151,9 @@ class DoubleStreamBlock(nn.Module):
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@ -160,46 +166,65 @@ class DoubleStreamBlock(nn.Module):
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
if self.modulation:
|
||||
img_mod1, img_mod2 = self.img_mod(vec)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
||||
else:
|
||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||
|
||||
# prepare image for attention
|
||||
img_modulated = self.img_norm1(img)
|
||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||
img_qkv = self.img_attn.qkv(img_modulated)
|
||||
del img_modulated
|
||||
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del img_qkv
|
||||
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
||||
|
||||
# prepare txt for attention
|
||||
txt_modulated = self.txt_norm1(txt)
|
||||
txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
|
||||
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
||||
del txt_modulated
|
||||
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del txt_qkv
|
||||
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
||||
|
||||
if self.flipped_img_txt:
|
||||
q = torch.cat((img_q, txt_q), dim=2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat((img_k, txt_k), dim=2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat((img_v, txt_v), dim=2)
|
||||
del img_v, txt_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((img_q, txt_q), dim=2),
|
||||
torch.cat((img_k, txt_k), dim=2),
|
||||
torch.cat((img_v, txt_v), dim=2),
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
|
||||
else:
|
||||
q = torch.cat((txt_q, img_q), dim=2)
|
||||
del txt_q, img_q
|
||||
k = torch.cat((txt_k, img_k), dim=2)
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
# run actual attention
|
||||
attn = attention(torch.cat((txt_q, img_q), dim=2),
|
||||
torch.cat((txt_k, img_k), dim=2),
|
||||
torch.cat((txt_v, img_v), dim=2),
|
||||
attn = attention(q, k, v,
|
||||
pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||
|
||||
# calculate the img bloks
|
||||
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||
del img_attn
|
||||
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||
del txt_attn
|
||||
txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
@ -220,6 +245,7 @@ class SingleStreamBlock(nn.Module):
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
modulation=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@ -242,19 +268,29 @@ class SingleStreamBlock(nn.Module):
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.modulation = None
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor:
|
||||
mod, _ = self.modulation(vec)
|
||||
if self.modulation:
|
||||
mod, _ = self.modulation(vec)
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
@ -7,7 +7,8 @@ import comfy.model_management
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
if pe is not None:
|
||||
q, k = apply_rope(q, k, pe)
|
||||
heads = q.shape[1]
|
||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||
return x
|
||||
|
||||
@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@ -42,6 +41,8 @@ class HunyuanVideoParams:
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@ -157,7 +158,10 @@ class TokenRefiner(nn.Module):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
@ -196,11 +200,15 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@ -266,6 +274,18 @@ class HunyuanVideo(nn.Module):
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@ -276,6 +296,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@ -331,12 +352,31 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@ -430,14 +470,14 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
@ -445,5 +485,5 @@ class HunyuanVideo(nn.Module):
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
120
comfy/ldm/hunyuan_video/upsampler.py
Normal file
120
comfy/ldm/hunyuan_video/upsampler.py
Normal file
@ -0,0 +1,120 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
|
||||
import model_management, model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
class SRModel3DV2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 6,
|
||||
global_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
|
||||
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
|
||||
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
|
||||
self.global_residual = bool(global_residual)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
y = self.in_conv(x)
|
||||
for blk in self.blocks:
|
||||
y = blk(y)
|
||||
y = self.out_conv(y)
|
||||
if self.global_residual and (y.shape == residual.shape):
|
||||
y = y + residual
|
||||
return y
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: tuple[int, ...],
|
||||
num_res_blocks: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.block_out_channels = block_out_channels
|
||||
self.z_channels = z_channels
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_shortcut=False,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Args:
|
||||
z: (B, C, T, H, W)
|
||||
target_shape: (H, W)
|
||||
"""
|
||||
# z to block_in
|
||||
repeats = self.block_out_channels[0] // (self.z_channels)
|
||||
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# upsampling
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
return out
|
||||
|
||||
UPSAMPLERS = {
|
||||
"720p": SRModel3DV2,
|
||||
"1080p": Upsampler,
|
||||
}
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
@ -4,8 +4,40 @@ import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class NoPadConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
if isinstance(op, NoPadConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@ -14,7 +46,7 @@ class RMS_norm(nn.Module):
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * self.gamma
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
|
||||
@ -27,11 +59,12 @@ class DnSmpl(nn.Module):
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = self.conv(x)
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
|
||||
if self.tds and self.refiner_vae:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
@ -39,14 +72,7 @@ class DnSmpl(nn.Module):
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nf = frms // r1
|
||||
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@ -54,34 +80,32 @@ class DnSmpl(nn.Module):
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
b, ci, frms, ht, wd = xn.shape
|
||||
nf = frms // r1
|
||||
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xn.shape
|
||||
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
if h.shape[2] == 0:
|
||||
return hf + xf
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = sc.shape
|
||||
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
return h + sc
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
@ -94,11 +118,11 @@ class UpSmpl(nn.Module):
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = self.conv(x)
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tus and self.refiner_vae:
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
@ -107,14 +131,7 @@ class UpSmpl(nn.Module):
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@ -125,29 +142,43 @@ class UpSmpl(nn.Module):
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = xn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
sc = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = sc.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
return h + sc
|
||||
x = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = x.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
class HunyuanRefinerResnetBlock(ResnetBlock):
|
||||
def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
|
||||
super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
h = x
|
||||
h = [ self.swish(self.norm1(x)) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
h = [ self.dropout(self.swish(self.norm2(h))) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
@ -160,7 +191,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = NoPadConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@ -175,10 +206,9 @@ class Encoder(nn.Module):
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
@ -188,9 +218,9 @@ class Encoder(nn.Module):
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
@ -201,31 +231,50 @@ class Encoder(nn.Module):
|
||||
if not self.refiner_vae and x.shape[2] == 1:
|
||||
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
||||
|
||||
x = self.conv_in(x)
|
||||
if self.refiner_vae:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.ffactor_temporal:
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
conv_carry_in = None
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
|
||||
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@ -239,7 +288,7 @@ class Decoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = NoPadConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@ -249,9 +298,9 @@ class Decoder(nn.Module):
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
@ -259,10 +308,9 @@ class Decoder(nn.Module):
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
conv_op=conv_op, norm_op=norm_op)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
if i < depth:
|
||||
@ -275,27 +323,41 @@ class Decoder(nn.Module):
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
if self.refiner_vae:
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
if self.refiner_vae:
|
||||
x = torch.split(x, 2, dim=2)
|
||||
else:
|
||||
x = [ x ]
|
||||
out = []
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
x1 = [ F.silu(self.norm_out(x1)) ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
if len(out) > 1:
|
||||
out = torch.cat(out, dim=2)
|
||||
else:
|
||||
out = out[0]
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
out = out[:, :, -1:]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@ -236,10 +236,10 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
|
||||
|
||||
img_normed = self.img_norm1(hidden_states)
|
||||
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
|
||||
txt_normed = self.txt_norm1(encoder_hidden_states)
|
||||
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
del img_mod1
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
del txt_mod1
|
||||
|
||||
img_attn_output, txt_attn_output = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
@ -248,16 +248,20 @@ class QwenImageTransformerBlock(nn.Module):
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
del img_modulated
|
||||
del txt_modulated
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn_output
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
|
||||
del img_attn_output
|
||||
del txt_attn_output
|
||||
del img_gate1
|
||||
del txt_gate1
|
||||
|
||||
img_normed2 = self.img_norm2(hidden_states)
|
||||
img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
|
||||
|
||||
txt_normed2 = self.txt_norm2(encoder_hidden_states)
|
||||
txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
@ -1536,3 +1536,94 @@ class HunyuanImage21Refiner(HunyuanImage21):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(True)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
|
||||
if conditioning_byt5small is not None:
|
||||
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
|
||||
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
clip_vision_output = kwargs.get("clip_vision_output", None)
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
|
||||
|
||||
return out
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
else:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
|
||||
if noise_augmentation > 0:
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(kwargs.get("seed", 0) - 10)
|
||||
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
|
||||
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
|
||||
else:
|
||||
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
@ -186,6 +186,16 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
|
||||
# HunyuanVideo 1.5
|
||||
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_cond_type_embedding"] = True
|
||||
else:
|
||||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
|
||||
@ -504,6 +504,7 @@ class LoadedModel:
|
||||
if use_more_vram == 0:
|
||||
use_more_vram = 1e32
|
||||
self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights)
|
||||
|
||||
real_model = self.model.model
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None:
|
||||
@ -689,7 +690,10 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
|
||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
lowvram_model_memory = 0.1
|
||||
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 0.1
|
||||
@ -1100,6 +1104,9 @@ def pin_memory(tensor):
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if type(tensor) is not torch.nn.parameter.Parameter:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
return False
|
||||
|
||||
@ -1109,6 +1116,9 @@ def pin_memory(tensor):
|
||||
#on the GPU async. So dont trust the CUDA API and guard here
|
||||
return False
|
||||
|
||||
if not tensor.is_contiguous():
|
||||
return False
|
||||
|
||||
size = tensor.numel() * tensor.element_size()
|
||||
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||
return False
|
||||
|
||||
@ -843,7 +843,7 @@ class ModelPatcher:
|
||||
|
||||
self.object_patches_backup.clear()
|
||||
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False):
|
||||
with self.use_ejected():
|
||||
hooks_unpatched = False
|
||||
memory_freed = 0
|
||||
@ -887,13 +887,19 @@ class ModelPatcher:
|
||||
module_mem += move_weight_functions(m, device_to)
|
||||
if lowvram_possible:
|
||||
if weight_key in self.patches:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight:
|
||||
@ -909,6 +915,7 @@ class ModelPatcher:
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
@ -921,6 +928,9 @@ class ModelPatcher:
|
||||
extra_memory += (used - self.model.model_loaded_weight_memory)
|
||||
|
||||
self.patch_model(load_weights=False)
|
||||
if extra_memory < 0 and not unpatch_weights:
|
||||
self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights)
|
||||
return 0
|
||||
full_load = False
|
||||
if self.model.model_lowvram == False and self.model.model_loaded_weight_memory > 0:
|
||||
self.apply_hooks(self.forced_hooks, force_apply=True)
|
||||
|
||||
44
comfy/ops.py
44
comfy/ops.py
@ -58,7 +58,8 @@ except (ModuleNotFoundError, TypeError):
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||
try:
|
||||
if comfy.model_management.is_nvidia():
|
||||
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
#TODO: change upper bound version once it's fixed'
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||
logging.info("working around nvidia conv3d memory bug.")
|
||||
@ -77,7 +78,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
# will add async-offload support to your cast and improve performance.
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if isinstance(input, QuantizedTensor):
|
||||
dtype = input._layout_params["orig_dtype"]
|
||||
else:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
@ -110,9 +114,9 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
weight = weight.to(dtype=dtype)
|
||||
if weight_has_function:
|
||||
if weight_has_function or weight.dtype != dtype:
|
||||
with wf_context:
|
||||
weight = weight.to(dtype=dtype)
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
@ -534,18 +538,7 @@ if CUBLAS_IS_AVAILABLE:
|
||||
# ==============================================================================
|
||||
# Mixed Precision Operations
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor
|
||||
|
||||
QUANT_FORMAT_MIXINS = {
|
||||
"float8_e4m3fn": {
|
||||
"dtype": torch.float8_e4m3fn,
|
||||
"layout_type": "TensorCoreFP8Layout",
|
||||
"parameters": {
|
||||
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||
}
|
||||
}
|
||||
}
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {}
|
||||
@ -596,23 +589,24 @@ class MixedPrecisionOps(disable_weight_init):
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
mixin = QUANT_FORMAT_MIXINS[quant_format]
|
||||
self.layout_type = mixin["layout_type"]
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(scale_key)
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name, param_value in mixin["parameters"].items():
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
@ -643,7 +637,7 @@ class MixedPrecisionOps(disable_weight_init):
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
|
||||
|
||||
@ -74,6 +74,12 @@ def _copy_layout_params(params):
|
||||
new_params[k] = v
|
||||
return new_params
|
||||
|
||||
def _copy_layout_params_inplace(src, dst, non_blocking=False):
|
||||
for k, v in src.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
dst[k].copy_(v, non_blocking=non_blocking)
|
||||
else:
|
||||
dst[k] = v
|
||||
|
||||
class QuantizedLayout:
|
||||
"""
|
||||
@ -318,13 +324,13 @@ def generic_to_dtype_layout(func, args, kwargs):
|
||||
def generic_copy_(func, args, kwargs):
|
||||
qt_dest = args[0]
|
||||
src = args[1]
|
||||
|
||||
non_blocking = args[2] if len(args) > 2 else False
|
||||
if isinstance(qt_dest, QuantizedTensor):
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata)
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
qt_dest._layout_params = _copy_layout_params(src._layout_params)
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
@ -336,6 +342,26 @@ def generic_copy_(func, args, kwargs):
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.empty_like.default)
|
||||
def generic_empty_like(func, args, kwargs):
|
||||
"""Empty_like operation - creates an empty tensor with the same quantized structure."""
|
||||
qt = args[0]
|
||||
if isinstance(qt, QuantizedTensor):
|
||||
# Create empty tensor with same shape and dtype as the quantized data
|
||||
hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"])
|
||||
new_qdata = torch.empty_like(qt._qdata, **kwargs)
|
||||
|
||||
# Handle device transfer for layout params
|
||||
target_device = kwargs.get('device', new_qdata.device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
|
||||
# Update orig_dtype if dtype is specified
|
||||
new_params['orig_dtype'] = hp_dtype
|
||||
|
||||
return QuantizedTensor(new_qdata, qt._layout_type, new_params)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
# ==============================================================================
|
||||
# FP8 Layout + Operation Handlers
|
||||
# ==============================================================================
|
||||
@ -378,6 +404,13 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
return qtensor._qdata, qtensor._layout_params['scale']
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreFP8Layout",
|
||||
},
|
||||
}
|
||||
|
||||
LAYOUTS = {
|
||||
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||
|
||||
12
comfy/sd.py
12
comfy/sd.py
@ -441,20 +441,20 @@ class VAE:
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.latent_channels = 64
|
||||
self.latent_channels = 32
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = True
|
||||
self.not_video = False
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@ -911,6 +911,7 @@ class CLIPType(Enum):
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
@ -1126,6 +1127,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
|
||||
@ -460,7 +460,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
return embed_out
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
|
||||
def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data={}, tokenizer_args={}):
|
||||
if tokenizer_path is None:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
|
||||
@ -468,6 +468,7 @@ class SDTokenizer:
|
||||
self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length)
|
||||
self.end_token = None
|
||||
self.min_padding = min_padding
|
||||
self.pad_left = pad_left
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
@ -522,6 +523,12 @@ class SDTokenizer:
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def pad_tokens(self, tokens, amount):
|
||||
if self.pad_left:
|
||||
for i in range(amount):
|
||||
tokens.insert(0, (self.pad_token, 1.0, 0))
|
||||
else:
|
||||
tokens.extend([(self.pad_token, 1.0, 0)] * amount)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
|
||||
'''
|
||||
@ -600,7 +607,7 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
self.pad_tokens(batch, remaining_length)
|
||||
#start new batch
|
||||
batch = []
|
||||
if self.start_token is not None:
|
||||
@ -614,11 +621,11 @@ class SDTokenizer:
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if min_padding is not None:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
|
||||
self.pad_tokens(batch, min_padding)
|
||||
if self.pad_to_max_length and len(batch) < self.max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
self.pad_tokens(batch, self.max_length - len(batch))
|
||||
if min_length is not None and len(batch) < min_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
|
||||
self.pad_tokens(batch, min_length - len(batch))
|
||||
|
||||
if not return_word_ids:
|
||||
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
|
||||
|
||||
@ -1374,6 +1374,54 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 7.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
"in_channels": 98,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.model_management
|
||||
import comfy.text_encoders.llama
|
||||
from .hunyuan_image import HunyuanImageTokenizer
|
||||
from transformers import LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
@ -73,6 +74,14 @@ class HunyuanVideoTokenizer:
|
||||
return {}
|
||||
|
||||
|
||||
class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
|
||||
|
||||
class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
|
||||
@ -32,6 +32,7 @@ class Llama2Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@ -53,6 +54,7 @@ class Qwen25_3BConfig:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@ -74,6 +76,7 @@ class Qwen25_7BVLI_Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@ -96,6 +99,7 @@ class Gemma2_2B_Config:
|
||||
k_norm = None
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@ -118,6 +122,7 @@ class Gemma3_4B_Config:
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [False, False, False, False, False, 1024]
|
||||
rope_scale = [1.0, 8.0]
|
||||
final_norm: bool = True
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@ -366,7 +371,12 @@ class Llama2_(nn.Module):
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
@ -421,14 +431,16 @@ class Llama2_(nn.Module):
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
x = self.norm(x)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
|
||||
@ -17,12 +17,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if text.startswith('<|start_header_id|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
|
||||
@ -7,7 +7,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io as io
|
||||
from . import _ui as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
@ -104,6 +104,8 @@ class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
|
||||
@ -27,6 +27,7 @@ from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classpr
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
||||
|
||||
@ -628,6 +629,10 @@ class UpscaleModel(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = ImageModelDescriptor
|
||||
|
||||
@comfytype(io_type="LATENT_UPSCALE_MODEL")
|
||||
class LatentUpscaleModel(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO")
|
||||
class Audio(ComfyTypeIO):
|
||||
class AudioDict(TypedDict):
|
||||
@ -656,11 +661,11 @@ class LossMap(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="VOXEL")
|
||||
class Voxel(ComfyTypeIO):
|
||||
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = VOXEL
|
||||
|
||||
@comfytype(io_type="MESH")
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
]
|
||||
|
||||
12
comfy_api/latest/_util/geometry_types.py
Normal file
12
comfy_api/latest/_util/geometry_types.py
Normal file
@ -0,0 +1,12 @@
|
||||
import torch
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data: torch.Tensor):
|
||||
self.data = data
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
@ -1,22 +1,230 @@
|
||||
from typing import Optional
|
||||
from datetime import date
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GeminiSafetyCategory(str, Enum):
|
||||
HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
|
||||
HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
|
||||
HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
|
||||
HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
|
||||
|
||||
|
||||
class GeminiSafetyThreshold(str, Enum):
|
||||
OFF = "OFF"
|
||||
BLOCK_NONE = "BLOCK_NONE"
|
||||
BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
|
||||
BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
|
||||
BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
|
||||
|
||||
|
||||
class GeminiSafetySetting(BaseModel):
|
||||
category: GeminiSafetyCategory
|
||||
threshold: GeminiSafetyThreshold
|
||||
|
||||
|
||||
class GeminiRole(str, Enum):
|
||||
user = "user"
|
||||
model = "model"
|
||||
|
||||
|
||||
class GeminiMimeType(str, Enum):
|
||||
application_pdf = "application/pdf"
|
||||
audio_mpeg = "audio/mpeg"
|
||||
audio_mp3 = "audio/mp3"
|
||||
audio_wav = "audio/wav"
|
||||
image_png = "image/png"
|
||||
image_jpeg = "image/jpeg"
|
||||
image_webp = "image/webp"
|
||||
text_plain = "text/plain"
|
||||
video_mov = "video/mov"
|
||||
video_mpeg = "video/mpeg"
|
||||
video_mp4 = "video/mp4"
|
||||
video_mpg = "video/mpg"
|
||||
video_avi = "video/avi"
|
||||
video_wmv = "video/wmv"
|
||||
video_mpegps = "video/mpegps"
|
||||
video_flv = "video/flv"
|
||||
|
||||
|
||||
class GeminiInlineData(BaseModel):
|
||||
data: str | None = Field(
|
||||
None,
|
||||
description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
|
||||
"When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
|
||||
)
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiTextPart(BaseModel):
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
parts: list[GeminiPart] = Field([])
|
||||
role: GeminiRole = Field(..., examples=["user"])
|
||||
|
||||
|
||||
class GeminiSystemInstructionContent(BaseModel):
|
||||
parts: list[GeminiTextPart] = Field(
|
||||
...,
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
description: str | None = Field(None)
|
||||
name: str = Field(...)
|
||||
parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
|
||||
|
||||
|
||||
class GeminiTool(BaseModel):
|
||||
functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
|
||||
|
||||
|
||||
class GeminiOffset(BaseModel):
|
||||
nanos: int | None = Field(None, ge=0, le=999999999)
|
||||
seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
|
||||
|
||||
|
||||
class GeminiVideoMetadata(BaseModel):
|
||||
endOffset: GeminiOffset | None = Field(None)
|
||||
startOffset: GeminiOffset | None = Field(None)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(1, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(40, ge=1)
|
||||
topP: float | None = Field(0.95, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
|
||||
aspectRatio: Optional[str] = None
|
||||
aspectRatio: str | None = Field(None)
|
||||
imageSize: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: Optional[list[str]] = None
|
||||
imageConfig: Optional[GeminiImageConfig] = None
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
imageConfig: GeminiImageConfig | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageGenerateContentRequest(BaseModel):
|
||||
contents: list[GeminiContent]
|
||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
||||
safetySettings: Optional[list[GeminiSafetySetting]] = None
|
||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
||||
tools: Optional[list[GeminiTool]] = None
|
||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
||||
contents: list[GeminiContent] = Field(...)
|
||||
generationConfig: GeminiImageGenerationConfig | None = Field(None)
|
||||
safetySettings: list[GeminiSafetySetting] | None = Field(None)
|
||||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||
tools: list[GeminiTool] | None = Field(None)
|
||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||
|
||||
|
||||
class GeminiGenerateContentRequest(BaseModel):
|
||||
contents: list[GeminiContent] = Field(...)
|
||||
generationConfig: GeminiGenerationConfig | None = Field(None)
|
||||
safetySettings: list[GeminiSafetySetting] | None = Field(None)
|
||||
systemInstruction: GeminiSystemInstructionContent | None = Field(None)
|
||||
tools: list[GeminiTool] | None = Field(None)
|
||||
videoMetadata: GeminiVideoMetadata | None = Field(None)
|
||||
|
||||
|
||||
class Modality(str, Enum):
|
||||
MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
|
||||
TEXT = "TEXT"
|
||||
IMAGE = "IMAGE"
|
||||
VIDEO = "VIDEO"
|
||||
AUDIO = "AUDIO"
|
||||
DOCUMENT = "DOCUMENT"
|
||||
|
||||
|
||||
class ModalityTokenCount(BaseModel):
|
||||
modality: Modality | None = None
|
||||
tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
|
||||
|
||||
|
||||
class Probability(str, Enum):
|
||||
NEGLIGIBLE = "NEGLIGIBLE"
|
||||
LOW = "LOW"
|
||||
MEDIUM = "MEDIUM"
|
||||
HIGH = "HIGH"
|
||||
UNKNOWN = "UNKNOWN"
|
||||
|
||||
|
||||
class GeminiSafetyRating(BaseModel):
|
||||
category: GeminiSafetyCategory | None = None
|
||||
probability: Probability | None = Field(
|
||||
None,
|
||||
description="The probability that the content violates the specified safety category",
|
||||
)
|
||||
|
||||
|
||||
class GeminiCitation(BaseModel):
|
||||
authors: list[str] | None = None
|
||||
endIndex: int | None = None
|
||||
license: str | None = None
|
||||
publicationDate: date | None = None
|
||||
startIndex: int | None = None
|
||||
title: str | None = None
|
||||
uri: str | None = None
|
||||
|
||||
|
||||
class GeminiCitationMetadata(BaseModel):
|
||||
citations: list[GeminiCitation] | None = None
|
||||
|
||||
|
||||
class GeminiCandidate(BaseModel):
|
||||
citationMetadata: GeminiCitationMetadata | None = None
|
||||
content: GeminiContent | None = None
|
||||
finishReason: str | None = None
|
||||
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||
|
||||
|
||||
class GeminiPromptFeedback(BaseModel):
|
||||
blockReason: str | None = None
|
||||
blockReasonMessage: str | None = None
|
||||
safetyRatings: list[GeminiSafetyRating] | None = None
|
||||
|
||||
|
||||
class GeminiUsageMetadata(BaseModel):
|
||||
cachedContentTokenCount: int | None = Field(
|
||||
None,
|
||||
description="Output only. Number of tokens in the cached part in the input (the cached content).",
|
||||
)
|
||||
candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
|
||||
candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||
None, description="Breakdown of candidate tokens by modality."
|
||||
)
|
||||
promptTokenCount: int | None = Field(
|
||||
None,
|
||||
description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
|
||||
)
|
||||
promptTokensDetails: list[ModalityTokenCount] | None = Field(
|
||||
None, description="Breakdown of prompt tokens by modality."
|
||||
)
|
||||
thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
|
||||
toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
|
||||
|
||||
|
||||
class GeminiGenerateContentResponse(BaseModel):
|
||||
candidates: list[GeminiCandidate] | None = Field(None)
|
||||
promptFeedback: GeminiPromptFeedback | None = Field(None)
|
||||
usageMetadata: GeminiUsageMetadata | None = Field(None)
|
||||
modelVersion: str | None = Field(None)
|
||||
|
||||
133
comfy_api_nodes/apis/topaz_api.py
Normal file
133
comfy_api_nodes/apis/topaz_api.py
Normal file
@ -0,0 +1,133 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageEnhanceRequest(BaseModel):
|
||||
model: str = Field("Reimagine")
|
||||
output_format: str = Field("jpeg")
|
||||
subject_detection: str = Field("All")
|
||||
face_enhancement: bool = Field(True)
|
||||
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
|
||||
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
|
||||
source_url: str = Field(...)
|
||||
output_width: Optional[int] = Field(None)
|
||||
output_height: Optional[int] = Field(None)
|
||||
crop_to_fill: bool = Field(False)
|
||||
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
|
||||
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
|
||||
face_preservation: str = Field("true", description="To preserve the identity of characters")
|
||||
color_preservation: str = Field("true", description="To preserve the original color")
|
||||
|
||||
|
||||
class ImageAsyncTaskResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
|
||||
|
||||
class ImageStatusResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
credits: int = Field(...)
|
||||
|
||||
|
||||
class ImageDownloadResponse(BaseModel):
|
||||
download_url: str = Field(...)
|
||||
expiry: int = Field(...)
|
||||
|
||||
|
||||
class Resolution(BaseModel):
|
||||
width: int = Field(...)
|
||||
height: int = Field(...)
|
||||
|
||||
|
||||
class CreateCreateVideoRequestSource(BaseModel):
|
||||
container: str = Field(...)
|
||||
size: int = Field(..., description="Size of the video file in bytes")
|
||||
duration: int = Field(..., description="Duration of the video file in seconds")
|
||||
frameCount: int = Field(..., description="Total number of frames in the video")
|
||||
frameRate: int = Field(...)
|
||||
resolution: Resolution = Field(...)
|
||||
|
||||
|
||||
class VideoFrameInterpolationFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
slowmo: Optional[int] = Field(None)
|
||||
fps: int = Field(...)
|
||||
duplicate: bool = Field(...)
|
||||
duplicate_threshold: float = Field(...)
|
||||
|
||||
|
||||
class VideoEnhancementFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
|
||||
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
|
||||
compression: Optional[float] = Field(None, description="Strength of compression recovery")
|
||||
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
|
||||
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
|
||||
noise: Optional[float] = Field(None, description="Amount of noise reduction")
|
||||
halo: Optional[float] = Field(None, description="Amount of halo reduction")
|
||||
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
|
||||
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
resolution: Resolution = Field(...)
|
||||
frameRate: int = Field(...)
|
||||
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
|
||||
audioTransfer: str = Field(..., description="Copy, Convert, None")
|
||||
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
|
||||
|
||||
|
||||
class Overrides(BaseModel):
|
||||
isPaidDiffusion: bool = Field(True)
|
||||
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateCreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
class CreateVideoResponse(BaseModel):
|
||||
requestId: str = Field(...)
|
||||
|
||||
|
||||
class VideoAcceptResponse(BaseModel):
|
||||
uploadId: str = Field(...)
|
||||
urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequestPart(BaseModel):
|
||||
partNum: int = Field(...)
|
||||
eTag: str = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequest(BaseModel):
|
||||
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadResponse(BaseModel):
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
|
||||
|
||||
class VideoStatusResponseEstimates(BaseModel):
|
||||
cost: list[int] = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponseDownloadUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
|
||||
progress: Optional[float] = Field(None)
|
||||
message: Optional[str] = Field("")
|
||||
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)
|
||||
@ -3,8 +3,6 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
@ -12,7 +10,7 @@ import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
@ -20,23 +18,24 @@ from typing_extensions import override
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api_nodes.apis import (
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiInlineData,
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
)
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiImageConfig,
|
||||
GeminiImageGenerateContentRequest,
|
||||
GeminiImageGenerationConfig,
|
||||
GeminiInlineData,
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiRole,
|
||||
Modality,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
@ -57,6 +56,7 @@ class GeminiModel(str, Enum):
|
||||
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
|
||||
gemini_2_5_pro = "gemini-2.5-pro"
|
||||
gemini_2_5_flash = "gemini-2.5-flash"
|
||||
gemini_3_0_pro = "gemini-3-pro-preview"
|
||||
|
||||
|
||||
class GeminiImageModel(str, Enum):
|
||||
@ -103,6 +103,16 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
Returns:
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
if response.candidates is None:
|
||||
if response.promptFeedback.blockReason:
|
||||
feedback = response.promptFeedback
|
||||
raise ValueError(
|
||||
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
|
||||
)
|
||||
raise NotImplementedError(
|
||||
"Gemini returned no response candidates. "
|
||||
"Please report to ComfyUI repository with the example of workflow to reproduce this."
|
||||
)
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||
@ -139,6 +149,49 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
|
||||
if not response.modelVersion:
|
||||
return None
|
||||
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
|
||||
if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
|
||||
input_tokens_price = 1.25
|
||||
output_text_tokens_price = 10.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion in (
|
||||
"gemini-2.5-flash-preview-04-17",
|
||||
"gemini-2.5-flash",
|
||||
):
|
||||
input_tokens_price = 0.30
|
||||
output_text_tokens_price = 2.50
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion in (
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-2.5-flash-image",
|
||||
):
|
||||
input_tokens_price = 0.30
|
||||
output_text_tokens_price = 2.50
|
||||
output_image_tokens_price = 30.0
|
||||
elif response.modelVersion == "gemini-3-pro-preview":
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3-pro-image-preview":
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 120.0
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.thoughtsTokenCount:
|
||||
final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
|
||||
return final_price / 1_000_000.0
|
||||
|
||||
|
||||
class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@ -272,10 +325,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
audio: Optional[Input.Audio] = None,
|
||||
video: Optional[Input.Video] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
images: torch.Tensor | None = None,
|
||||
audio: Input.Audio | None = None,
|
||||
video: Input.Video | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -300,15 +353,15 @@ class GeminiNode(IO.ComfyNode):
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role="user",
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
]
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
# Get result output
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
@ -406,7 +459,7 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
|
||||
def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput:
|
||||
"""Loads and formats input files for Gemini API."""
|
||||
if GEMINI_INPUT_FILES is None:
|
||||
GEMINI_INPUT_FILES = []
|
||||
@ -421,7 +474,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImageNode",
|
||||
display_name="Google Gemini Image",
|
||||
display_name="Nano Banana (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
description="Edit images synchronously via Google API.",
|
||||
inputs=[
|
||||
@ -469,6 +522,13 @@ class GeminiImage(IO.ComfyNode):
|
||||
"or otherwise generates 1:1 squares.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE+TEXT", "IMAGE"],
|
||||
tooltip="Choose 'IMAGE' for image-only output, or "
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -488,9 +548,10 @@ class GeminiImage(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[GeminiPart]] = None,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
aspect_ratio: str = "auto",
|
||||
response_modalities: str = "IMAGE+TEXT",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
@ -510,20 +571,19 @@ class GeminiImage(IO.ComfyNode):
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role="user", parts=parts),
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT", "IMAGE"],
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
@ -544,9 +604,150 @@ class GeminiImage(IO.ComfyNode):
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), output_text)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return IO.NodeOutput(output_image, output_text)
|
||||
|
||||
class GeminiImage2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImage2Node",
|
||||
display_name="Nano Banana Pro (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt describing the image to generate or the edits to apply. "
|
||||
"Include any constraints, styles, or details the model should follow.",
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gemini-3-pro-image-preview"],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
"if no image is provided, generates a 1:1 square.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE+TEXT", "IMAGE"],
|
||||
tooltip="Choose 'IMAGE' for image-only output, or "
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
optional=True,
|
||||
tooltip="Optional reference image(s). "
|
||||
"To include multiple images, use the Batch Images node (up to 14).",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
response_modalities: str,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
if images is not None:
|
||||
if get_number_of_images(images) > 14:
|
||||
raise ValueError("The current maximum number of supported images is 14.")
|
||||
parts.extend(create_image_parts(images))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
image_config = GeminiImageConfig(imageSize=resolution)
|
||||
if aspect_ratio != "auto":
|
||||
image_config.aspectRatio = aspect_ratio
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
),
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), output_text)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@ -555,6 +756,7 @@ class GeminiExtension(ComfyExtension):
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -518,7 +518,9 @@ async def execute_lipsync(
|
||||
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = await upload_audio_to_comfyapi(cls, audio)
|
||||
audio_url = await upload_audio_to_comfyapi(
|
||||
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
|
||||
)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
|
||||
421
comfy_api_nodes/nodes_topaz.py
Normal file
421
comfy_api_nodes/nodes_topaz.py
Normal file
@ -0,0 +1,421 @@
|
||||
import builtins
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api_nodes.apis import topaz_api
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_fs_object_size,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
)
|
||||
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
}
|
||||
UPSCALER_VALUES_MAP = {
|
||||
"FullHD (1080p)": 1920,
|
||||
"4K (2160p)": 3840,
|
||||
}
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazImageEnhance",
|
||||
display_name="Topaz Image Enhance",
|
||||
category="api node/image/Topaz",
|
||||
description="Industry-standard upscaling and image enhancement.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["Reimagine"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional text prompt for creative upscaling guidance.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"subject_detection",
|
||||
options=["All", "Foreground", "Background"],
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"face_enhancement",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Enhance faces (if present) during processing.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"face_enhancement_creativity",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Set the creativity level for face enhancement.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"face_enhancement_strength",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Controls how sharp enhanced faces are relative to the background.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"crop_to_fill",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="By default, the image is letterboxed when the output aspect ratio differs. "
|
||||
"Enable to crop the image to fill the output dimensions.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"output_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=32000,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"output_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=32000,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Zero value means to output in the same height as original or output width.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"creativity",
|
||||
default=3,
|
||||
min=1,
|
||||
max=9,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"face_preservation",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Preserve subjects' facial identity.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"color_preservation",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Preserve the original colors.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str = "",
|
||||
subject_detection: str = "All",
|
||||
face_enhancement: bool = True,
|
||||
face_enhancement_creativity: float = 1.0,
|
||||
face_enhancement_strength: float = 0.8,
|
||||
crop_to_fill: bool = False,
|
||||
output_width: int = 0,
|
||||
output_height: int = 0,
|
||||
creativity: int = 3,
|
||||
face_preservation: bool = True,
|
||||
color_preservation: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
|
||||
response_model=topaz_api.ImageAsyncTaskResponse,
|
||||
data=topaz_api.ImageEnhanceRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
subject_detection=subject_detection,
|
||||
face_enhancement=face_enhancement,
|
||||
face_enhancement_creativity=face_enhancement_creativity,
|
||||
face_enhancement_strength=face_enhancement_strength,
|
||||
crop_to_fill=crop_to_fill,
|
||||
output_width=output_width if output_width else None,
|
||||
output_height=output_height if output_height else None,
|
||||
creativity=creativity,
|
||||
face_preservation=str(face_preservation).lower(),
|
||||
color_preservation=str(color_preservation).lower(),
|
||||
source_url=download_url[0],
|
||||
output_format="png",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: x.credits * 0.08,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
estimated_duration=60,
|
||||
)
|
||||
|
||||
results = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageDownloadResponse,
|
||||
monitor_progress=False,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
|
||||
|
||||
|
||||
class TopazVideoEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
options=["low", "middle", "high"],
|
||||
default="low",
|
||||
tooltip="Creativity level (applies only to Starlight (Astra) Creative).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("interpolation_enabled", default=False, optional=True),
|
||||
IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True),
|
||||
IO.Int.Input(
|
||||
"interpolation_slowmo",
|
||||
default=1,
|
||||
min=1,
|
||||
max=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Slow-motion factor applied to the input video. "
|
||||
"For example, 2 makes the output twice as slow and doubles the duration.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"interpolation_frame_rate",
|
||||
default=60,
|
||||
min=15,
|
||||
max=240,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Output frame rate.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolation_duplicate",
|
||||
default=False,
|
||||
tooltip="Analyze the input for duplicate frames and remove them.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"interpolation_duplicate_threshold",
|
||||
default=0.01,
|
||||
min=0.001,
|
||||
max=0.1,
|
||||
step=0.001,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Detection sensitivity for duplicate frames.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"dynamic_compression_level",
|
||||
options=["Low", "Mid", "High"],
|
||||
default="Low",
|
||||
tooltip="CQP level.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
upscaler_enabled: bool,
|
||||
upscaler_model: str,
|
||||
upscaler_resolution: str,
|
||||
upscaler_creativity: str = "low",
|
||||
interpolation_enabled: bool = False,
|
||||
interpolation_model: str = "apo-8",
|
||||
interpolation_slowmo: int = 1,
|
||||
interpolation_frame_rate: int = 60,
|
||||
interpolation_duplicate: bool = False,
|
||||
interpolation_duplicate_threshold: float = 0.01,
|
||||
dynamic_compression_level: str = "Low",
|
||||
) -> IO.NodeOutput:
|
||||
if upscaler_enabled is False and interpolation_enabled is False:
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
src_width, src_height = video.get_dimensions()
|
||||
video_components = video.get_components()
|
||||
src_frame_rate = int(video_components.frame_rate)
|
||||
duration_sec = video.get_duration()
|
||||
estimated_frames = int(duration_sec * src_frame_rate)
|
||||
validate_container_format_is_mp4(video)
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_enabled:
|
||||
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
filters.append(
|
||||
topaz_api.VideoEnhancementFilter(
|
||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
),
|
||||
)
|
||||
if interpolation_enabled:
|
||||
target_frame_rate = interpolation_frame_rate
|
||||
filters.append(
|
||||
topaz_api.VideoFrameInterpolationFilter(
|
||||
model=interpolation_model,
|
||||
slowmo=interpolation_slowmo,
|
||||
fps=interpolation_frame_rate,
|
||||
duplicate=interpolation_duplicate,
|
||||
duplicate_threshold=interpolation_duplicate_threshold,
|
||||
),
|
||||
)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=topaz_api.CreateVideoResponse,
|
||||
data=topaz_api.CreateVideoRequest(
|
||||
source=topaz_api.CreateCreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=estimated_frames,
|
||||
frameRate=src_frame_rate,
|
||||
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=topaz_api.OutputInformationVideo(
|
||||
resolution=topaz_api.Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
dynamicCompressionLevel=dynamic_compression_level,
|
||||
),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
upload_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
if len(upload_res.urls) > 1:
|
||||
raise NotImplementedError(
|
||||
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
|
||||
)
|
||||
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
|
||||
if isinstance(src_video_stream, BytesIO):
|
||||
src_video_stream.seek(0)
|
||||
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
else:
|
||||
with builtins.open(src_video_stream, "rb") as video_file:
|
||||
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoCompleteUploadResponse,
|
||||
data=topaz_api.VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
topaz_api.VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
],
|
||||
),
|
||||
wait_label="Finalizing upload",
|
||||
final_label_on_success="Upload completed",
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=topaz_api.VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TopazImageEnhance,
|
||||
TopazVideoEnhance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> TopazExtension:
|
||||
return TopazExtension()
|
||||
@ -63,6 +63,7 @@ class _RequestConfig:
|
||||
estimated_total: Optional[int] = None
|
||||
final_label_on_success: Optional[str] = "Completed"
|
||||
progress_origin_ts: Optional[float] = None
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -77,9 +78,9 @@ class _PollUIState:
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
@ -87,6 +88,7 @@ async def sync_op(
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
@ -104,6 +106,7 @@ async def sync_op(
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
endpoint,
|
||||
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||
data=data,
|
||||
files=files,
|
||||
content_type=content_type,
|
||||
@ -175,6 +178,7 @@ async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
content_type: str = "application/json",
|
||||
@ -216,6 +220,7 @@ async def sync_op_raw(
|
||||
estimated_total=estimated_duration,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@ -424,7 +429,9 @@ def _display_text(
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
if price is not None:
|
||||
display_lines.append(f"Price: ${float(price):,.4f}")
|
||||
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
|
||||
if p != "0":
|
||||
display_lines.append(f"Price: ${p}")
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
@ -580,6 +587,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: Optional[int] = None
|
||||
extracted_price: Optional[float] = None
|
||||
while True:
|
||||
attempt += 1
|
||||
stop_event = asyncio.Event()
|
||||
@ -767,6 +775,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
@ -871,7 +881,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
else int(time.monotonic() - start_time)
|
||||
),
|
||||
estimated_total=cfg.estimated_total,
|
||||
price=None,
|
||||
price=extracted_price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=final_elapsed_seconds,
|
||||
)
|
||||
|
||||
@ -11,13 +11,13 @@ if TYPE_CHECKING:
|
||||
|
||||
def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
transformer_options: dict[str] = args[-1]
|
||||
if not isinstance(transformer_options, dict):
|
||||
transformer_options = kwargs.get("transformer_options")
|
||||
if not transformer_options:
|
||||
transformer_options = args[-2]
|
||||
easycache: EasyCacheHolder = transformer_options["easycache"]
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
sigmas = transformer_options["sigmas"]
|
||||
uuids = transformer_options["uuids"]
|
||||
if sigmas is not None and easycache.is_past_end_timestep(sigmas):
|
||||
@ -82,13 +82,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
|
||||
|
||||
def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
|
||||
# get values from args
|
||||
x: torch.Tensor = args[0]
|
||||
timestep: float = args[1]
|
||||
model_options: dict[str] = args[2]
|
||||
easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
|
||||
if easycache.is_past_end_timestep(timestep):
|
||||
return executor(*args, **kwargs)
|
||||
# prepare next x_prev
|
||||
x: torch.Tensor = args[0][:, :easycache.output_channels]
|
||||
next_x_prev = x
|
||||
input_change = None
|
||||
do_easycache = easycache.should_do_easycache(timestep)
|
||||
@ -173,7 +173,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
|
||||
|
||||
|
||||
class EasyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
|
||||
self.name = "EasyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
@ -202,6 +202,7 @@ class EasyCacheHolder:
|
||||
self.allow_mismatch = True
|
||||
self.cut_from_start = True
|
||||
self.state_metadata = None
|
||||
self.output_channels = output_channels
|
||||
|
||||
def is_past_end_timestep(self, timestep: float) -> bool:
|
||||
return not (timestep[0] > self.end_t).item()
|
||||
@ -264,7 +265,7 @@ class EasyCacheHolder:
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
batch_slice = batch_slice + slicing
|
||||
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||
return x
|
||||
|
||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||
@ -283,7 +284,7 @@ class EasyCacheHolder:
|
||||
else:
|
||||
slicing.append(slice(None))
|
||||
skip_dim = False
|
||||
x = x[slicing]
|
||||
x = x[tuple(slicing)]
|
||||
diff = output - x
|
||||
batch_offset = diff.shape[0] // len(uuids)
|
||||
for i, uuid in enumerate(uuids):
|
||||
@ -323,7 +324,7 @@ class EasyCacheHolder:
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
|
||||
|
||||
|
||||
class EasyCacheNode(io.ComfyNode):
|
||||
@ -350,7 +351,7 @@ class EasyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
|
||||
@ -358,7 +359,7 @@ class EasyCacheNode(io.ComfyNode):
|
||||
|
||||
|
||||
class LazyCacheHolder:
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
|
||||
def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None):
|
||||
self.name = "LazyCache"
|
||||
self.reuse_threshold = reuse_threshold
|
||||
self.start_percent = start_percent
|
||||
@ -382,6 +383,7 @@ class LazyCacheHolder:
|
||||
self.approx_output_change_rates = []
|
||||
self.total_steps_skipped = 0
|
||||
self.state_metadata = None
|
||||
self.output_channels = output_channels
|
||||
|
||||
def has_cache_diff(self) -> bool:
|
||||
return self.cache_diff is not None
|
||||
@ -456,7 +458,7 @@ class LazyCacheHolder:
|
||||
return self
|
||||
|
||||
def clone(self):
|
||||
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)
|
||||
return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels)
|
||||
|
||||
class LazyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -482,7 +484,7 @@ class LazyCacheNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
|
||||
model = model.clone()
|
||||
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
|
||||
model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
|
||||
model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
@ -4,7 +4,8 @@ import torch
|
||||
import comfy.model_management
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel
|
||||
import folder_paths
|
||||
|
||||
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -57,6 +58,199 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
schema = super().define_schema()
|
||||
schema.node_id = "EmptyHunyuanVideo15Latent"
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||
# Using scale factor of 16 instead of 8
|
||||
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class HunyuanVideo15ImageToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15ImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device())
|
||||
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
|
||||
encoded = vae.encode(start_image[:, :, :, :3])
|
||||
concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device())
|
||||
concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded
|
||||
|
||||
mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
|
||||
mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class HunyuanVideo15SuperResolution(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15SuperResolution",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae", optional=True),
|
||||
io.Image.Input("start_image", optional=True),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Latent.Input("latent"),
|
||||
io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01),
|
||||
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput:
|
||||
in_latent = latent["samples"]
|
||||
in_channels = in_latent.shape[1]
|
||||
cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent
|
||||
cond_latent[:, 2 * in_channels + 1] = 1
|
||||
if start_image is not None:
|
||||
start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1)
|
||||
encoded = vae.encode(start_image[:, :, :, :3])
|
||||
cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded
|
||||
cond_latent[:, in_channels + 1, 0] = 1
|
||||
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation})
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
return io.NodeOutput(positive, negative, latent)
|
||||
|
||||
|
||||
class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentUpscaleModelLoader",
|
||||
display_name="Load Latent Upscale Model",
|
||||
category="loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
|
||||
],
|
||||
outputs=[
|
||||
io.LatentUpscaleModel.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name) -> io.NodeOutput:
|
||||
model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name)
|
||||
sd = comfy.utils.load_torch_file(model_path, safe_load=True)
|
||||
|
||||
if "blocks.0.block.0.conv.weight" in sd:
|
||||
config = {
|
||||
"in_channels": sd["in_conv.conv.weight"].shape[1],
|
||||
"out_channels": sd["out_conv.conv.weight"].shape[0],
|
||||
"hidden_channels": sd["in_conv.conv.weight"].shape[0],
|
||||
"num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]),
|
||||
"global_residual": False,
|
||||
}
|
||||
model_type = "720p"
|
||||
elif "up.0.block.0.conv1.conv.weight" in sd:
|
||||
sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()}
|
||||
config = {
|
||||
"z_channels": sd["conv_in.conv.weight"].shape[1],
|
||||
"out_channels": sd["conv_out.conv.weight"].shape[0],
|
||||
"block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))),
|
||||
}
|
||||
model_type = "1080p"
|
||||
|
||||
model = HunyuanVideo15SRModel(model_type, config)
|
||||
model.load_sd(sd)
|
||||
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15LatentUpscaleWithModel",
|
||||
display_name="Hunyuan Video 15 Latent Upscale With Model",
|
||||
category="latent",
|
||||
inputs=[
|
||||
io.LatentUpscaleModel.Input("model"),
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"),
|
||||
io.Int.Input("width", default=1280, min=0, max=16384, step=8),
|
||||
io.Int.Input("height", default=720, min=0, max=16384, step=8),
|
||||
io.Combo.Input("crop", options=["disabled", "center"]),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput:
|
||||
if width == 0 and height == 0:
|
||||
return io.NodeOutput(samples)
|
||||
else:
|
||||
if width == 0:
|
||||
height = max(64, height)
|
||||
width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2]))
|
||||
elif height == 0:
|
||||
width = max(64, width)
|
||||
height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1]))
|
||||
else:
|
||||
width = max(64, width)
|
||||
height = max(64, height)
|
||||
s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop)
|
||||
s = model.resample_latent(s)
|
||||
return io.NodeOutput({"samples": s.cpu().float()})
|
||||
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
@ -210,6 +404,11 @@ class HunyuanExtension(ComfyExtension):
|
||||
CLIPTextEncodeHunyuanDiT,
|
||||
TextEncodeHunyuanVideo_ImageToVideo,
|
||||
EmptyHunyuanLatentVideo,
|
||||
EmptyHunyuanVideo15Latent,
|
||||
HunyuanVideo15ImageToVideo,
|
||||
HunyuanVideo15SuperResolution,
|
||||
HunyuanVideo15LatentUpscaleWithModel,
|
||||
LatentUpscaleModelLoader,
|
||||
HunyuanImageToVideo,
|
||||
EmptyHunyuanImageLatent,
|
||||
HunyuanRefinerLatent,
|
||||
|
||||
@ -7,63 +7,79 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro
|
||||
import folder_paths
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
|
||||
|
||||
class EmptyLatentHunyuan3Dv2:
|
||||
|
||||
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentHunyuan3Dv2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
|
||||
def generate(self, resolution, batch_size):
|
||||
@classmethod
|
||||
def execute(cls, resolution, batch_size) -> IO.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
|
||||
return ({"samples": latent, "type": "hunyuan3dv2"}, )
|
||||
return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"})
|
||||
|
||||
class Hunyuan3Dv2Conditioning:
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2Conditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, clip_vision_output):
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_output) -> IO.NodeOutput:
|
||||
embeds = clip_vision_output.last_hidden_state
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
class Hunyuan3Dv2ConditioningMultiView:
|
||||
class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {},
|
||||
"optional": {"front": ("CLIP_VISION_OUTPUT",),
|
||||
"left": ("CLIP_VISION_OUTPUT",),
|
||||
"back": ("CLIP_VISION_OUTPUT",),
|
||||
"right": ("CLIP_VISION_OUTPUT",), }}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("front", optional=True),
|
||||
IO.ClipVisionOutput.Input("left", optional=True),
|
||||
IO.ClipVisionOutput.Input("back", optional=True),
|
||||
IO.ClipVisionOutput.Input("right", optional=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "conditioning/video_models"
|
||||
|
||||
def encode(self, front=None, left=None, back=None, right=None):
|
||||
@classmethod
|
||||
def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput:
|
||||
all_embeds = [front, left, back, right]
|
||||
out = []
|
||||
pos_embeds = None
|
||||
@ -76,29 +92,35 @@ class Hunyuan3Dv2ConditioningMultiView:
|
||||
embeds = torch.cat(out, dim=1)
|
||||
positive = [[embeds, {}]]
|
||||
negative = [[torch.zeros_like(embeds), {}]]
|
||||
return (positive, negative)
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
class VAEDecodeHunyuan3D:
|
||||
class VAEDecodeHunyuan3D(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ),
|
||||
"vae": ("VAE", ),
|
||||
"num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
|
||||
"octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
|
||||
}}
|
||||
RETURN_TYPES = ("VOXEL",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeHunyuan3D",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Int.Input("num_chunks", default=8000, min=1000, max=500000),
|
||||
IO.Int.Input("octree_resolution", default=256, min=16, max=512),
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
CATEGORY = "latent/3d"
|
||||
@classmethod
|
||||
def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput:
|
||||
voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return IO.NodeOutput(voxels)
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
def decode(self, vae, samples, num_chunks, octree_resolution):
|
||||
voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
|
||||
return (voxels, )
|
||||
|
||||
def voxel_to_mesh(voxels, threshold=0.5, device=None):
|
||||
if device is None:
|
||||
@ -396,24 +418,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
|
||||
|
||||
return final_vertices, faces
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices, faces):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
|
||||
|
||||
class VoxelToMeshBasic:
|
||||
class VoxelToMeshBasic(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMeshBasic",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, threshold):
|
||||
@classmethod
|
||||
def execute(cls, voxel, threshold) -> IO.NodeOutput:
|
||||
vertices = []
|
||||
faces = []
|
||||
for x in voxel.data:
|
||||
@ -421,21 +443,29 @@ class VoxelToMeshBasic:
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
|
||||
class VoxelToMesh:
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
class VoxelToMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"voxel": ("VOXEL", ),
|
||||
"algorithm": (["surface net", "basic"], ),
|
||||
"threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
|
||||
}}
|
||||
RETURN_TYPES = ("MESH",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VoxelToMesh",
|
||||
category="3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("voxel"),
|
||||
IO.Combo.Input("algorithm", options=["surface net", "basic"]),
|
||||
IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output(),
|
||||
]
|
||||
)
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def decode(self, voxel, algorithm, threshold):
|
||||
@classmethod
|
||||
def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput:
|
||||
vertices = []
|
||||
faces = []
|
||||
|
||||
@ -449,7 +479,9 @@ class VoxelToMesh:
|
||||
vertices.append(v)
|
||||
faces.append(f)
|
||||
|
||||
return (MESH(torch.stack(vertices), torch.stack(faces)), )
|
||||
return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
@ -581,31 +613,32 @@ def save_glb(vertices, faces, filepath, metadata=None):
|
||||
return filepath
|
||||
|
||||
|
||||
class SaveGLB:
|
||||
class SaveGLB(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"mesh": ("MESH", ),
|
||||
"filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, }
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveGLB",
|
||||
category="3d",
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.String.Input("filename_prefix", default="mesh/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo]
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "3d"
|
||||
|
||||
def save(self, mesh, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
@classmethod
|
||||
def execute(cls, mesh, filename_prefix) -> IO.NodeOutput:
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||
results = []
|
||||
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
if cls.hidden.prompt is not None:
|
||||
metadata["prompt"] = json.dumps(cls.hidden.prompt)
|
||||
if cls.hidden.extra_pnginfo is not None:
|
||||
for x in cls.hidden.extra_pnginfo:
|
||||
metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
|
||||
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
@ -616,15 +649,22 @@ class SaveGLB:
|
||||
"type": "output"
|
||||
})
|
||||
counter += 1
|
||||
return {"ui": {"3d": results}}
|
||||
return IO.NodeOutput(ui={"3d": results})
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2,
|
||||
"Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning,
|
||||
"Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView,
|
||||
"VAEDecodeHunyuan3D": VAEDecodeHunyuan3D,
|
||||
"VoxelToMeshBasic": VoxelToMeshBasic,
|
||||
"VoxelToMesh": VoxelToMesh,
|
||||
"SaveGLB": SaveGLB,
|
||||
}
|
||||
class Hunyuan3dExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
EmptyLatentHunyuan3Dv2,
|
||||
Hunyuan3Dv2Conditioning,
|
||||
Hunyuan3Dv2ConditioningMultiView,
|
||||
VAEDecodeHunyuan3D,
|
||||
VoxelToMeshBasic,
|
||||
VoxelToMesh,
|
||||
SaveGLB,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Hunyuan3dExtension:
|
||||
return Hunyuan3dExtension()
|
||||
|
||||
39
comfy_extras/nodes_nop.py
Normal file
39
comfy_extras/nodes_nop.py
Normal file
@ -0,0 +1,39 @@
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from typing_extensions import override
|
||||
# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list
|
||||
|
||||
# "native" block swap nodes are placebo at best and break the ComfyUI memory management system.
|
||||
# They are also considered harmful because instead of users reporting issues with the built in
|
||||
# memory management they install these stupid nodes and complain even harder. Now it completely
|
||||
# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it
|
||||
# out of all workflows.
|
||||
class wanBlockSwap(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="wanBlockSwap",
|
||||
category="",
|
||||
description="NOP",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(),
|
||||
],
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model) -> io.NodeOutput:
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
class NopExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
wanBlockSwap
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> NopExtension:
|
||||
return NopExtension()
|
||||
@ -39,5 +39,5 @@ NODE_CLASS_MAPPINGS = {
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"PreviewAny": "Preview Any",
|
||||
"PreviewAny": "Preview as Text",
|
||||
}
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.3.68"
|
||||
__version__ = "0.3.71"
|
||||
|
||||
@ -38,6 +38,8 @@ folder_names_and_paths["gligen"] = ([os.path.join(models_dir, "gligen")], suppor
|
||||
|
||||
folder_names_and_paths["upscale_models"] = ([os.path.join(models_dir, "upscale_models")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["latent_upscale_models"] = ([os.path.join(models_dir, "latent_upscale_models")], supported_pt_extensions)
|
||||
|
||||
folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes")], set())
|
||||
|
||||
folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions)
|
||||
|
||||
9
nodes.py
9
nodes.py
@ -957,7 +957,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ),
|
||||
},
|
||||
"optional": {
|
||||
"device": (["default", "cpu"], {"advanced": True}),
|
||||
@ -1852,6 +1852,11 @@ class ImageBatch:
|
||||
CATEGORY = "image"
|
||||
|
||||
def batch(self, image1, image2):
|
||||
if image1.shape[-1] != image2.shape[-1]:
|
||||
if image1.shape[-1] > image2.shape[-1]:
|
||||
image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0)
|
||||
else:
|
||||
image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0)
|
||||
if image1.shape[1:] != image2.shape[1:]:
|
||||
image2 = comfy.utils.common_upscale(image2.movedim(-1,1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1,-1)
|
||||
s = torch.cat((image1, image2), dim=0)
|
||||
@ -2330,6 +2335,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_easycache.py",
|
||||
"nodes_audio_encoder.py",
|
||||
"nodes_rope.py",
|
||||
"nodes_nop.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
@ -2358,6 +2364,7 @@ async def init_builtin_api_nodes():
|
||||
"nodes_pika.py",
|
||||
"nodes_runway.py",
|
||||
"nodes_sora.py",
|
||||
"nodes_topaz.py",
|
||||
"nodes_tripo.py",
|
||||
"nodes_moonvalley.py",
|
||||
"nodes_rodin.py",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.3.68"
|
||||
version = "0.3.71"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.9"
|
||||
@ -24,7 +24,7 @@ lint.select = [
|
||||
exclude = ["*.ipynb", "**/generated/*.pyi"]
|
||||
|
||||
[tool.pylint]
|
||||
master.py-version = "3.9"
|
||||
master.py-version = "3.10"
|
||||
master.extension-pkg-allow-list = [
|
||||
"pydantic",
|
||||
]
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
comfyui-frontend-package==1.30.6
|
||||
comfyui-workflow-templates==0.6.0
|
||||
comfyui-embedded-docs==0.3.1
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
|
||||
53
server.py
53
server.py
@ -2,6 +2,7 @@ import os
|
||||
import sys
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
|
||||
import nodes
|
||||
import folder_paths
|
||||
@ -29,7 +30,7 @@ import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager
|
||||
from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
|
||||
from app.user_manager import UserManager
|
||||
@ -163,6 +164,22 @@ def create_origin_only_middleware():
|
||||
|
||||
return origin_only_middleware
|
||||
|
||||
|
||||
def create_block_external_middleware():
|
||||
@web.middleware
|
||||
async def block_external_middleware(request: web.Request, handler):
|
||||
if request.method == "OPTIONS":
|
||||
# Pre-flight request. Reply successfully:
|
||||
response = web.Response()
|
||||
else:
|
||||
response = await handler(request)
|
||||
|
||||
response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';"
|
||||
return response
|
||||
|
||||
return block_external_middleware
|
||||
|
||||
|
||||
class PromptServer():
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
@ -192,6 +209,9 @@ class PromptServer():
|
||||
else:
|
||||
middlewares.append(create_origin_only_middleware())
|
||||
|
||||
if args.disable_api_nodes:
|
||||
middlewares.append(create_block_external_middleware())
|
||||
|
||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||
self.sockets = dict()
|
||||
@ -733,6 +753,7 @@ class PromptServer():
|
||||
for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS:
|
||||
if sensitive_val in extra_data:
|
||||
sensitive[sensitive_val] = extra_data.pop(sensitive_val)
|
||||
extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds
|
||||
self.prompt_queue.put((number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive))
|
||||
response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]}
|
||||
return web.json_response(response)
|
||||
@ -847,11 +868,31 @@ class PromptServer():
|
||||
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||
self.app.add_routes([web.static('/extensions/' + name, dir)])
|
||||
|
||||
workflow_templates_path = FrontendManager.templates_path()
|
||||
if workflow_templates_path:
|
||||
self.app.add_routes([
|
||||
web.static('/templates', workflow_templates_path)
|
||||
])
|
||||
installed_templates_version = FrontendManager.get_installed_templates_version()
|
||||
use_legacy_templates = True
|
||||
if installed_templates_version:
|
||||
try:
|
||||
use_legacy_templates = (
|
||||
parse_version(installed_templates_version)
|
||||
< parse_version("0.3.0")
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.warning(
|
||||
"Unable to parse templates version '%s': %s",
|
||||
installed_templates_version,
|
||||
exc,
|
||||
)
|
||||
|
||||
if use_legacy_templates:
|
||||
workflow_templates_path = FrontendManager.legacy_templates_path()
|
||||
if workflow_templates_path:
|
||||
self.app.add_routes([
|
||||
web.static('/templates', workflow_templates_path)
|
||||
])
|
||||
else:
|
||||
handler = FrontendManager.template_asset_handler()
|
||||
if handler:
|
||||
self.app.router.add_get("/templates/{path:.*}", handler)
|
||||
|
||||
# Serve embedded documentation from the package
|
||||
embedded_docs_path = FrontendManager.embedded_docs_path()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user