mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2026-05-17 04:09:09 +08:00
Merge branch 'master' into offloader-maifee
This commit is contained in:
commit
e07a32c9b8
@ -0,0 +1,3 @@
|
|||||||
|
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
|
||||||
|
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||||
|
pause
|
||||||
@ -1,2 +1,3 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
|
||||||
|
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||||
pause
|
pause
|
||||||
|
|||||||
@ -1,2 +1,3 @@
|
|||||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
|
||||||
|
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
|
||||||
pause
|
pause
|
||||||
|
|||||||
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
8
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@ -8,13 +8,15 @@ body:
|
|||||||
Before submitting a **Bug Report**, please ensure the following:
|
Before submitting a **Bug Report**, please ensure the following:
|
||||||
|
|
||||||
- **1:** You are running the latest version of ComfyUI.
|
- **1:** You are running the latest version of ComfyUI.
|
||||||
- **2:** You have looked at the existing bug reports and made sure this isn't already reported.
|
- **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report.
|
||||||
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
- **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing
|
||||||
`--disable-all-custom-nodes` command line argument.
|
`--disable-all-custom-nodes` command line argument. If you have custom node try updating them to the latest version.
|
||||||
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
- **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact
|
||||||
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
|
||||||
|
|
||||||
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
|
## Very Important
|
||||||
|
|
||||||
|
Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
|
||||||
- type: checkboxes
|
- type: checkboxes
|
||||||
id: custom-nodes-test
|
id: custom-nodes-test
|
||||||
attributes:
|
attributes:
|
||||||
|
|||||||
4
.github/workflows/release-stable-all.yml
vendored
4
.github/workflows/release-stable-all.yml
vendored
@ -18,9 +18,9 @@ jobs:
|
|||||||
uses: ./.github/workflows/stable-release.yml
|
uses: ./.github/workflows/stable-release.yml
|
||||||
with:
|
with:
|
||||||
git_tag: ${{ inputs.git_tag }}
|
git_tag: ${{ inputs.git_tag }}
|
||||||
cache_tag: "cu129"
|
cache_tag: "cu130"
|
||||||
python_minor: "13"
|
python_minor: "13"
|
||||||
python_patch: "6"
|
python_patch: "9"
|
||||||
rel_name: "nvidia"
|
rel_name: "nvidia"
|
||||||
rel_extra_name: ""
|
rel_extra_name: ""
|
||||||
test_release: true
|
test_release: true
|
||||||
|
|||||||
@ -17,7 +17,7 @@ on:
|
|||||||
description: 'cuda version'
|
description: 'cuda version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "129"
|
default: "130"
|
||||||
|
|
||||||
python_minor:
|
python_minor:
|
||||||
description: 'python minor version'
|
description: 'python minor version'
|
||||||
@ -29,7 +29,7 @@ on:
|
|||||||
description: 'python patch version'
|
description: 'python patch version'
|
||||||
required: true
|
required: true
|
||||||
type: string
|
type: string
|
||||||
default: "6"
|
default: "9"
|
||||||
# push:
|
# push:
|
||||||
# branches:
|
# branches:
|
||||||
# - master
|
# - master
|
||||||
|
|||||||
15
README.md
15
README.md
@ -112,10 +112,11 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
|
|
||||||
## Release Process
|
## Release Process
|
||||||
|
|
||||||
ComfyUI follows a weekly release cycle targeting Friday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||||
|
|
||||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||||
- Releases a new stable version (e.g., v0.7.0)
|
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
||||||
|
- Commits outside of the stable release tags may be very unstable and break many custom nodes.
|
||||||
- Serves as the foundation for the desktop release
|
- Serves as the foundation for the desktop release
|
||||||
|
|
||||||
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
|
||||||
@ -176,6 +177,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you
|
|||||||
|
|
||||||
If you have trouble extracting it, right click the file -> properties -> unblock
|
If you have trouble extracting it, right click the file -> properties -> unblock
|
||||||
|
|
||||||
|
Update your Nvidia drivers if it doesn't start.
|
||||||
|
|
||||||
#### Alternative Downloads:
|
#### Alternative Downloads:
|
||||||
|
|
||||||
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
|
||||||
@ -197,7 +200,11 @@ comfy install
|
|||||||
|
|
||||||
## Manual Install (Windows, Linux)
|
## Manual Install (Windows, Linux)
|
||||||
|
|
||||||
Python 3.13 is very well supported. If you have trouble with some custom node dependencies you can try 3.12
|
Python 3.14 will work if you comment out the `kornia` dependency in the requirements.txt file (breaks the canny node) but it is not recommended.
|
||||||
|
|
||||||
|
Python 3.13 is very well supported. If you have trouble with some custom node dependencies on 3.13 you can try 3.12
|
||||||
|
|
||||||
|
### Instructions:
|
||||||
|
|
||||||
Git clone this repo.
|
Git clone this repo.
|
||||||
|
|
||||||
@ -253,7 +260,7 @@ This is the command to install the Pytorch xpu nightly which might have some per
|
|||||||
|
|
||||||
Nvidia users should install stable pytorch using this command:
|
Nvidia users should install stable pytorch using this command:
|
||||||
|
|
||||||
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu129```
|
```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu130```
|
||||||
|
|
||||||
This is the command to install pytorch nightly instead which might have performance improvements.
|
This is the command to install pytorch nightly instead which might have performance improvements.
|
||||||
|
|
||||||
|
|||||||
112
app/subgraph_manager.py
Normal file
112
app/subgraph_manager.py
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TypedDict
|
||||||
|
import os
|
||||||
|
import folder_paths
|
||||||
|
import glob
|
||||||
|
from aiohttp import web
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
class Source:
|
||||||
|
custom_node = "custom_node"
|
||||||
|
|
||||||
|
class SubgraphEntry(TypedDict):
|
||||||
|
source: str
|
||||||
|
"""
|
||||||
|
Source of subgraph - custom_nodes vs templates.
|
||||||
|
"""
|
||||||
|
path: str
|
||||||
|
"""
|
||||||
|
Relative path of the subgraph file.
|
||||||
|
For custom nodes, will be the relative directory like <custom_node_dir>/subgraphs/<name>.json
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
"""
|
||||||
|
Name of subgraph file.
|
||||||
|
"""
|
||||||
|
info: CustomNodeSubgraphEntryInfo
|
||||||
|
"""
|
||||||
|
Additional info about subgraph; in the case of custom_nodes, will contain nodepack name
|
||||||
|
"""
|
||||||
|
data: str
|
||||||
|
|
||||||
|
class CustomNodeSubgraphEntryInfo(TypedDict):
|
||||||
|
node_pack: str
|
||||||
|
"""Node pack name."""
|
||||||
|
|
||||||
|
class SubgraphManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None
|
||||||
|
|
||||||
|
async def load_entry_data(self, entry: SubgraphEntry):
|
||||||
|
with open(entry['path'], 'r') as f:
|
||||||
|
entry['data'] = f.read()
|
||||||
|
return entry
|
||||||
|
|
||||||
|
async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None:
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
entry = entry.copy()
|
||||||
|
entry.pop('path', None)
|
||||||
|
if remove_data:
|
||||||
|
entry.pop('data', None)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]:
|
||||||
|
entries = entries.copy()
|
||||||
|
for key in list(entries.keys()):
|
||||||
|
entries[key] = await self.sanitize_entry(entries[key], remove_data)
|
||||||
|
return entries
|
||||||
|
|
||||||
|
async def get_custom_node_subgraphs(self, loadedModules, force_reload=False):
|
||||||
|
# if not forced to reload and cached, return cache
|
||||||
|
if not force_reload and self.cached_custom_node_subgraphs is not None:
|
||||||
|
return self.cached_custom_node_subgraphs
|
||||||
|
# Load subgraphs from custom nodes
|
||||||
|
subfolder = "subgraphs"
|
||||||
|
subgraphs_dict: dict[SubgraphEntry] = {}
|
||||||
|
|
||||||
|
for folder in folder_paths.get_folder_paths("custom_nodes"):
|
||||||
|
pattern = os.path.join(folder, f"*/{subfolder}/*.json")
|
||||||
|
matched_files = glob.glob(pattern)
|
||||||
|
for file in matched_files:
|
||||||
|
# replace backslashes with forward slashes
|
||||||
|
file = file.replace('\\', '/')
|
||||||
|
info: CustomNodeSubgraphEntryInfo = {
|
||||||
|
"node_pack": "custom_nodes." + file.split('/')[-3]
|
||||||
|
}
|
||||||
|
source = Source.custom_node
|
||||||
|
# hash source + path to make sure id will be as unique as possible, but
|
||||||
|
# reproducible across backend reloads
|
||||||
|
id = hashlib.sha256(f"{source}{file}".encode()).hexdigest()
|
||||||
|
entry: SubgraphEntry = {
|
||||||
|
"source": Source.custom_node,
|
||||||
|
"name": os.path.splitext(os.path.basename(file))[0],
|
||||||
|
"path": file,
|
||||||
|
"info": info,
|
||||||
|
}
|
||||||
|
subgraphs_dict[id] = entry
|
||||||
|
self.cached_custom_node_subgraphs = subgraphs_dict
|
||||||
|
return subgraphs_dict
|
||||||
|
|
||||||
|
async def get_custom_node_subgraph(self, id: str, loadedModules):
|
||||||
|
subgraphs = await self.get_custom_node_subgraphs(loadedModules)
|
||||||
|
entry: SubgraphEntry = subgraphs.get(id, None)
|
||||||
|
if entry is not None and entry.get('data', None) is None:
|
||||||
|
await self.load_entry_data(entry)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def add_routes(self, routes, loadedModules):
|
||||||
|
@routes.get("/global_subgraphs")
|
||||||
|
async def get_global_subgraphs(request):
|
||||||
|
subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules)
|
||||||
|
# NOTE: we may want to include other sources of global subgraphs such as templates in the future;
|
||||||
|
# that's the reasoning for the current implementation
|
||||||
|
return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True))
|
||||||
|
|
||||||
|
@routes.get("/global_subgraphs/{id}")
|
||||||
|
async def get_global_subgraph(request):
|
||||||
|
id = request.match_info.get("id", None)
|
||||||
|
subgraph = await self.get_custom_node_subgraph(id, loadedModules)
|
||||||
|
return web.json_response(await self.sanitize_entry(subgraph))
|
||||||
@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
|
|||||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||||
|
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||||
@ -156,7 +157,9 @@ class PerformanceFeature(enum.Enum):
|
|||||||
CublasOps = "cublas_ops"
|
CublasOps = "cublas_ops"
|
||||||
AutoTune = "autotune"
|
AutoTune = "autotune"
|
||||||
|
|
||||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||||
|
|
||||||
|
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||||
|
|
||||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||||
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
|
||||||
|
|||||||
@ -310,11 +310,13 @@ class ControlLoraOps:
|
|||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
x = torch.nn.functional.linear(input, weight, bias)
|
||||||
|
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -350,12 +352,13 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
class ControlLora(ControlNet):
|
class ControlLora(ControlNet):
|
||||||
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
|
||||||
|
|||||||
@ -189,15 +189,15 @@ class ChromaRadiance(Chroma):
|
|||||||
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
nerf_pixels = nn.functional.unfold(img_orig, kernel_size=patch_size, stride=patch_size)
|
||||||
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
nerf_pixels = nerf_pixels.transpose(1, 2) # -> [B, NumPatches, C * P * P]
|
||||||
|
|
||||||
|
# Reshape for per-patch processing
|
||||||
|
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
||||||
|
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
||||||
|
|
||||||
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
if params.nerf_tile_size > 0 and num_patches > params.nerf_tile_size:
|
||||||
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
# Enable tiling if nerf_tile_size isn't 0 and we actually have more patches than
|
||||||
# the tile size.
|
# the tile size.
|
||||||
img_dct = self.forward_tiled_nerf(img_out, nerf_pixels, B, C, num_patches, patch_size, params)
|
img_dct = self.forward_tiled_nerf(nerf_hidden, nerf_pixels, B, C, num_patches, patch_size, params)
|
||||||
else:
|
else:
|
||||||
# Reshape for per-patch processing
|
|
||||||
nerf_hidden = img_out.reshape(B * num_patches, params.hidden_size)
|
|
||||||
nerf_pixels = nerf_pixels.reshape(B * num_patches, C, patch_size**2).transpose(1, 2)
|
|
||||||
|
|
||||||
# Get DCT-encoded pixel embeddings [pixel-dct]
|
# Get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct = self.nerf_image_embedder(nerf_pixels)
|
img_dct = self.nerf_image_embedder(nerf_pixels)
|
||||||
|
|
||||||
@ -240,17 +240,8 @@ class ChromaRadiance(Chroma):
|
|||||||
end = min(i + tile_size, num_patches)
|
end = min(i + tile_size, num_patches)
|
||||||
|
|
||||||
# Slice the current tile from the input tensors
|
# Slice the current tile from the input tensors
|
||||||
nerf_hidden_tile = nerf_hidden[:, i:end, :]
|
nerf_hidden_tile = nerf_hidden[i * batch:end * batch]
|
||||||
nerf_pixels_tile = nerf_pixels[:, i:end, :]
|
nerf_pixels_tile = nerf_pixels[i * batch:end * batch]
|
||||||
|
|
||||||
# Get the actual number of patches in this tile (can be smaller for the last tile)
|
|
||||||
num_patches_tile = nerf_hidden_tile.shape[1]
|
|
||||||
|
|
||||||
# Reshape the tile for per-patch processing
|
|
||||||
# [B, NumPatches_tile, D] -> [B * NumPatches_tile, D]
|
|
||||||
nerf_hidden_tile = nerf_hidden_tile.reshape(batch * num_patches_tile, params.hidden_size)
|
|
||||||
# [B, NumPatches_tile, C*P*P] -> [B*NumPatches_tile, C, P*P] -> [B*NumPatches_tile, P*P, C]
|
|
||||||
nerf_pixels_tile = nerf_pixels_tile.reshape(batch * num_patches_tile, channels, patch_size**2).transpose(1, 2)
|
|
||||||
|
|
||||||
# get DCT-encoded pixel embeddings [pixel-dct]
|
# get DCT-encoded pixel embeddings [pixel-dct]
|
||||||
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
img_dct_tile = self.nerf_image_embedder(nerf_pixels_tile)
|
||||||
|
|||||||
@ -195,8 +195,8 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
|
||||||
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
|
||||||
|
|
||||||
# calculate the txt bloks
|
# calculate the txt bloks
|
||||||
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
|
||||||
|
|||||||
@ -7,15 +7,7 @@ import comfy.model_management
|
|||||||
|
|
||||||
|
|
||||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||||
q_shape = q.shape
|
q, k = apply_rope(q, k, pe)
|
||||||
k_shape = k.shape
|
|
||||||
|
|
||||||
if pe is not None:
|
|
||||||
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
|
|
||||||
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
|
|
||||||
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
|
|
||||||
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
|
|
||||||
|
|
||||||
heads = q.shape[1]
|
heads = q.shape[1]
|
||||||
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -210,7 +210,7 @@ class Flux(nn.Module):
|
|||||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def process_img(self, x, index=0, h_offset=0, w_offset=0):
|
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||||
bs, c, h, w = x.shape
|
bs, c, h, w = x.shape
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
|
||||||
@ -222,10 +222,22 @@ class Flux(nn.Module):
|
|||||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
steps_h = h_len
|
||||||
|
steps_w = w_len
|
||||||
|
|
||||||
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
if rope_options is not None:
|
||||||
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||||
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||||
|
|
||||||
|
index += rope_options.get("shift_t", 0.0)
|
||||||
|
h_offset += rope_options.get("shift_y", 0.0)
|
||||||
|
w_offset += rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
||||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
|
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||||
|
|
||||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||||
@ -241,7 +253,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
h_len = ((h_orig + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||||
img, img_ids = self.process_img(x)
|
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||||
img_tokens = img.shape[1]
|
img_tokens = img.shape[1]
|
||||||
if ref_latents is not None:
|
if ref_latents is not None:
|
||||||
h = 0
|
h = 0
|
||||||
|
|||||||
@ -3,12 +3,11 @@ from torch import nn
|
|||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
import comfy.ldm.modules.attention
|
import comfy.ldm.modules.attention
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
from einops import rearrange
|
|
||||||
import math
|
import math
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
|
||||||
def get_timestep_embedding(
|
def get_timestep_embedding(
|
||||||
timesteps: torch.Tensor,
|
timesteps: torch.Tensor,
|
||||||
@ -238,20 +237,6 @@ class FeedForward(nn.Module):
|
|||||||
return self.net(x)
|
return self.net(x)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
|
|
||||||
cos_freqs = freqs_cis[0]
|
|
||||||
sin_freqs = freqs_cis[1]
|
|
||||||
|
|
||||||
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
|
||||||
t1, t2 = t_dup.unbind(dim=-1)
|
|
||||||
t_dup = torch.stack((-t2, t1), dim=-1)
|
|
||||||
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
|
||||||
|
|
||||||
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class CrossAttention(nn.Module):
|
class CrossAttention(nn.Module):
|
||||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -281,8 +266,8 @@ class CrossAttention(nn.Module):
|
|||||||
k = self.k_norm(k)
|
k = self.k_norm(k)
|
||||||
|
|
||||||
if pe is not None:
|
if pe is not None:
|
||||||
q = apply_rotary_emb(q, pe)
|
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
|
||||||
k = apply_rotary_emb(k, pe)
|
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)
|
||||||
|
|
||||||
if mask is None:
|
if mask is None:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
@ -306,12 +291,17 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||||
|
|
||||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
|
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||||
|
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||||
|
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
|
||||||
|
x.addcmul_(attn1_input, gate_msa)
|
||||||
|
del attn1_input
|
||||||
|
|
||||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
|
y = comfy.ldm.common_dit.rms_norm(x)
|
||||||
x += self.ff(y) * gate_mlp
|
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||||
|
x.addcmul_(self.ff(y), gate_mlp)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):
|
|||||||
|
|
||||||
|
|
||||||
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
|
||||||
dtype = torch.float32 #self.dtype
|
dtype = torch.float32
|
||||||
|
device = indices_grid.device
|
||||||
|
|
||||||
|
# Get fractional positions and compute frequency indices
|
||||||
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
fractional_positions = get_fractional_positions(indices_grid, max_pos)
|
||||||
|
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2
|
||||||
|
|
||||||
start = 1
|
# Compute frequencies and apply cos/sin
|
||||||
end = theta
|
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
|
||||||
device = fractional_positions.device
|
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
|
||||||
|
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)
|
||||||
|
|
||||||
indices = theta ** (
|
# Pad if dim is not divisible by 6
|
||||||
torch.linspace(
|
|
||||||
math.log(start, theta),
|
|
||||||
math.log(end, theta),
|
|
||||||
dim // 6,
|
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
indices = indices.to(dtype=dtype)
|
|
||||||
|
|
||||||
indices = indices * math.pi / 2
|
|
||||||
|
|
||||||
freqs = (
|
|
||||||
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
|
||||||
.transpose(-1, -2)
|
|
||||||
.flatten(2)
|
|
||||||
)
|
|
||||||
|
|
||||||
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
|
||||||
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
|
||||||
if dim % 6 != 0:
|
if dim % 6 != 0:
|
||||||
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
padding_size = dim % 6
|
||||||
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
|
||||||
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)
|
||||||
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
|
||||||
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
|
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
|
||||||
|
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||||
|
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
|
||||||
|
|
||||||
|
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
|
||||||
|
freqs_cis = torch.stack([
|
||||||
|
torch.stack([cos_vals, -sin_vals], dim=-1),
|
||||||
|
torch.stack([sin_vals, cos_vals], dim=-1)
|
||||||
|
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]
|
||||||
|
|
||||||
|
return freqs_cis
|
||||||
|
|
||||||
|
|
||||||
class LTXVModel(torch.nn.Module):
|
class LTXVModel(torch.nn.Module):
|
||||||
@ -501,7 +485,7 @@ class LTXVModel(torch.nn.Module):
|
|||||||
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
|
||||||
x = self.norm_out(x)
|
x = self.norm_out(x)
|
||||||
# Modulation
|
# Modulation
|
||||||
x = x * (1 + scale) + shift
|
x = torch.addcmul(x, x, scale).add_(shift)
|
||||||
x = self.proj_out(x)
|
x = self.proj_out(x)
|
||||||
|
|
||||||
x = self.patchifier.unpatchify(
|
x = self.patchifier.unpatchify(
|
||||||
|
|||||||
@ -522,7 +522,7 @@ class NextDiT(nn.Module):
|
|||||||
max_cap_len = max(l_effective_cap_len)
|
max_cap_len = max(l_effective_cap_len)
|
||||||
max_img_len = max(l_effective_img_len)
|
max_img_len = max(l_effective_img_len)
|
||||||
|
|
||||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
|
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
for i in range(bsz):
|
for i in range(bsz):
|
||||||
cap_len = l_effective_cap_len[i]
|
cap_len = l_effective_cap_len[i]
|
||||||
@ -531,10 +531,22 @@ class NextDiT(nn.Module):
|
|||||||
H_tokens, W_tokens = H // pH, W // pW
|
H_tokens, W_tokens = H // pH, W // pW
|
||||||
assert H_tokens * W_tokens == img_len
|
assert H_tokens * W_tokens == img_len
|
||||||
|
|
||||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
h_scale = 1.0
|
||||||
|
w_scale = 1.0
|
||||||
|
h_start = 0
|
||||||
|
w_start = 0
|
||||||
|
if rope_options is not None:
|
||||||
|
h_scale = rope_options.get("scale_y", 1.0)
|
||||||
|
w_scale = rope_options.get("scale_x", 1.0)
|
||||||
|
|
||||||
|
h_start = rope_options.get("shift_y", 0.0)
|
||||||
|
w_start = rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
|
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
|
||||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||||
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
|
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||||
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
|
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||||
|
|
||||||
|
|||||||
0
comfy/ldm/mmaudio/vae/__init__.py
Normal file
0
comfy/ldm/mmaudio/vae/__init__.py
Normal file
120
comfy/ldm/mmaudio/vae/activations.py
Normal file
120
comfy/ldm/mmaudio/vae/activations.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn, sin, pow
|
||||||
|
from torch.nn import Parameter
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
class Snake(nn.Module):
|
||||||
|
'''
|
||||||
|
Implementation of a sine-based periodic activation function
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter
|
||||||
|
References:
|
||||||
|
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snake(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
'''
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
|
'''
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha: trainable parameter
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
'''
|
||||||
|
super(Snake, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale:
|
||||||
|
self.alpha = Parameter(torch.empty(in_features))
|
||||||
|
else:
|
||||||
|
self.alpha = Parameter(torch.empty(in_features))
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
Snake ∶= x + 1/a * sin^2 (xa)
|
||||||
|
'''
|
||||||
|
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
'''
|
||||||
|
A modified Snake function which uses separate parameters for the magnitude of the periodic components
|
||||||
|
Shape:
|
||||||
|
- Input: (B, C, T)
|
||||||
|
- Output: (B, C, T), same shape as the input
|
||||||
|
Parameters:
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
References:
|
||||||
|
- This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||||
|
https://arxiv.org/abs/2006.08195
|
||||||
|
Examples:
|
||||||
|
>>> a1 = snakebeta(256)
|
||||||
|
>>> x = torch.randn(256)
|
||||||
|
>>> x = a1(x)
|
||||||
|
'''
|
||||||
|
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||||
|
'''
|
||||||
|
Initialization.
|
||||||
|
INPUT:
|
||||||
|
- in_features: shape of the input
|
||||||
|
- alpha - trainable parameter that controls frequency
|
||||||
|
- beta - trainable parameter that controls magnitude
|
||||||
|
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||||
|
beta is initialized to 1 by default, higher values = higher-magnitude.
|
||||||
|
alpha will be trained along with the rest of your model.
|
||||||
|
'''
|
||||||
|
super(SnakeBeta, self).__init__()
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
# initialize alpha
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
if self.alpha_logscale:
|
||||||
|
self.alpha = Parameter(torch.empty(in_features))
|
||||||
|
self.beta = Parameter(torch.empty(in_features))
|
||||||
|
else:
|
||||||
|
self.alpha = Parameter(torch.empty(in_features))
|
||||||
|
self.beta = Parameter(torch.empty(in_features))
|
||||||
|
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
|
||||||
|
self.no_div_by_zero = 0.000000001
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
'''
|
||||||
|
Forward pass of the function.
|
||||||
|
Applies the function to the input elementwise.
|
||||||
|
SnakeBeta ∶= x + 1/b * sin^2 (xa)
|
||||||
|
'''
|
||||||
|
alpha = comfy.model_management.cast_to(self.alpha, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||||
|
beta = comfy.model_management.cast_to(self.beta, dtype=x.dtype, device=x.device).unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
alpha = torch.exp(alpha)
|
||||||
|
beta = torch.exp(beta)
|
||||||
|
x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||||
|
|
||||||
|
return x
|
||||||
157
comfy/ldm/mmaudio/vae/alias_free_torch.py
Normal file
157
comfy/ldm/mmaudio/vae/alias_free_torch.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
if 'sinc' in dir(torch):
|
||||||
|
sinc = torch.sinc
|
||||||
|
else:
|
||||||
|
# This code is adopted from adefossez's julius.core.sinc under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/core.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def sinc(x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Implementation of sinc, i.e. sin(pi * x) / (pi * x)
|
||||||
|
__Warning__: Different to julius.sinc, the input is multiplied by `pi`!
|
||||||
|
"""
|
||||||
|
return torch.where(x == 0,
|
||||||
|
torch.tensor(1., device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x)
|
||||||
|
|
||||||
|
|
||||||
|
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||||
|
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||||
|
even = (kernel_size % 2 == 0)
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
|
||||||
|
#For kaiser window
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if A > 50.:
|
||||||
|
beta = 0.1102 * (A - 8.7)
|
||||||
|
elif A >= 21.:
|
||||||
|
beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
|
||||||
|
else:
|
||||||
|
beta = 0.
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
|
||||||
|
# ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
|
||||||
|
if even:
|
||||||
|
time = (torch.arange(-half_size, half_size) + 0.5)
|
||||||
|
else:
|
||||||
|
time = torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
|
||||||
|
# Normalize filter to have sum = 1, otherwise we will have a small leakage
|
||||||
|
# of the constant component in the input signal.
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
filter = filter_.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
cutoff=0.5,
|
||||||
|
half_width=0.6,
|
||||||
|
stride: int = 1,
|
||||||
|
padding: bool = True,
|
||||||
|
padding_mode: str = 'replicate',
|
||||||
|
kernel_size: int = 12):
|
||||||
|
# kernel_size should be even number for stylegan3 setup,
|
||||||
|
# in this implementation, odd number is also possible.
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = (kernel_size % 2 == 0)
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
#input [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right),
|
||||||
|
mode=self.padding_mode)
|
||||||
|
out = F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device),
|
||||||
|
stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.stride = ratio
|
||||||
|
self.pad = self.kernel_size // ratio - 1
|
||||||
|
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
# x: [B, C, T]
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
|
||||||
|
x = F.pad(x, (self.pad, self.pad), mode='replicate')
|
||||||
|
x = self.ratio * F.conv_transpose1d(
|
||||||
|
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
|
||||||
|
x = x[..., self.pad_left:-self.pad_right]
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
stride=ratio,
|
||||||
|
kernel_size=self.kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
xx = self.lowpass(x)
|
||||||
|
|
||||||
|
return xx
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
activation,
|
||||||
|
up_ratio: int = 2,
|
||||||
|
down_ratio: int = 2,
|
||||||
|
up_kernel_size: int = 12,
|
||||||
|
down_kernel_size: int = 12):
|
||||||
|
super().__init__()
|
||||||
|
self.up_ratio = up_ratio
|
||||||
|
self.down_ratio = down_ratio
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
# x: [B,C,T]
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
|
||||||
|
return x
|
||||||
156
comfy/ldm/mmaudio/vae/autoencoder.py
Normal file
156
comfy/ldm/mmaudio/vae/autoencoder.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .distributions import DiagonalGaussianDistribution
|
||||||
|
from .vae import VAE_16k
|
||||||
|
from .bigvgan import BigVGANVocoder
|
||||||
|
import logging
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torchaudio
|
||||||
|
except:
|
||||||
|
logging.warning("torchaudio missing, MMAudio VAE model will be broken")
|
||||||
|
|
||||||
|
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5, *, norm_fn):
|
||||||
|
return norm_fn(torch.clamp(x, min=clip_val) * C)
|
||||||
|
|
||||||
|
|
||||||
|
def spectral_normalize_torch(magnitudes, norm_fn):
|
||||||
|
output = dynamic_range_compression_torch(magnitudes, norm_fn=norm_fn)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class MelConverter(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
sampling_rate: float,
|
||||||
|
n_fft: int,
|
||||||
|
num_mels: int,
|
||||||
|
hop_size: int,
|
||||||
|
win_size: int,
|
||||||
|
fmin: float,
|
||||||
|
fmax: float,
|
||||||
|
norm_fn,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.sampling_rate = sampling_rate
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.num_mels = num_mels
|
||||||
|
self.hop_size = hop_size
|
||||||
|
self.win_size = win_size
|
||||||
|
self.fmin = fmin
|
||||||
|
self.fmax = fmax
|
||||||
|
self.norm_fn = norm_fn
|
||||||
|
|
||||||
|
# mel = librosa_mel_fn(sr=self.sampling_rate,
|
||||||
|
# n_fft=self.n_fft,
|
||||||
|
# n_mels=self.num_mels,
|
||||||
|
# fmin=self.fmin,
|
||||||
|
# fmax=self.fmax)
|
||||||
|
# mel_basis = torch.from_numpy(mel).float()
|
||||||
|
mel_basis = torch.empty((num_mels, 1 + n_fft // 2))
|
||||||
|
hann_window = torch.hann_window(self.win_size)
|
||||||
|
|
||||||
|
self.register_buffer('mel_basis', mel_basis)
|
||||||
|
self.register_buffer('hann_window', hann_window)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.mel_basis.device
|
||||||
|
|
||||||
|
def forward(self, waveform: torch.Tensor, center: bool = False) -> torch.Tensor:
|
||||||
|
waveform = waveform.clamp(min=-1., max=1.).to(self.device)
|
||||||
|
|
||||||
|
waveform = torch.nn.functional.pad(
|
||||||
|
waveform.unsqueeze(1),
|
||||||
|
[int((self.n_fft - self.hop_size) / 2),
|
||||||
|
int((self.n_fft - self.hop_size) / 2)],
|
||||||
|
mode='reflect')
|
||||||
|
waveform = waveform.squeeze(1)
|
||||||
|
|
||||||
|
spec = torch.stft(waveform,
|
||||||
|
self.n_fft,
|
||||||
|
hop_length=self.hop_size,
|
||||||
|
win_length=self.win_size,
|
||||||
|
window=self.hann_window,
|
||||||
|
center=center,
|
||||||
|
pad_mode='reflect',
|
||||||
|
normalized=False,
|
||||||
|
onesided=True,
|
||||||
|
return_complex=True)
|
||||||
|
|
||||||
|
spec = torch.view_as_real(spec)
|
||||||
|
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
||||||
|
spec = torch.matmul(self.mel_basis, spec)
|
||||||
|
spec = spectral_normalize_torch(spec, self.norm_fn)
|
||||||
|
|
||||||
|
return spec
|
||||||
|
|
||||||
|
class AudioAutoencoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
# ckpt_path: str,
|
||||||
|
mode=Literal['16k', '44k'],
|
||||||
|
need_vae_encoder: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
assert mode == "16k", "Only 16k mode is supported currently."
|
||||||
|
self.mel_converter = MelConverter(sampling_rate=16_000,
|
||||||
|
n_fft=1024,
|
||||||
|
num_mels=80,
|
||||||
|
hop_size=256,
|
||||||
|
win_size=1024,
|
||||||
|
fmin=0,
|
||||||
|
fmax=8_000,
|
||||||
|
norm_fn=torch.log10)
|
||||||
|
|
||||||
|
self.vae = VAE_16k().eval()
|
||||||
|
|
||||||
|
bigvgan_config = {
|
||||||
|
"resblock": "1",
|
||||||
|
"num_mels": 80,
|
||||||
|
"upsample_rates": [4, 4, 2, 2, 2, 2],
|
||||||
|
"upsample_kernel_sizes": [8, 8, 4, 4, 4, 4],
|
||||||
|
"upsample_initial_channel": 1536,
|
||||||
|
"resblock_kernel_sizes": [3, 7, 11],
|
||||||
|
"resblock_dilation_sizes": [
|
||||||
|
[1, 3, 5],
|
||||||
|
[1, 3, 5],
|
||||||
|
[1, 3, 5],
|
||||||
|
],
|
||||||
|
"activation": "snakebeta",
|
||||||
|
"snake_logscale": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.vocoder = BigVGANVocoder(
|
||||||
|
bigvgan_config
|
||||||
|
).eval()
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def encode_audio(self, x) -> DiagonalGaussianDistribution:
|
||||||
|
# x: (B * L)
|
||||||
|
mel = self.mel_converter(x)
|
||||||
|
dist = self.vae.encode(mel)
|
||||||
|
|
||||||
|
return dist
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def decode(self, z):
|
||||||
|
mel_decoded = self.vae.decode(z)
|
||||||
|
audio = self.vocoder(mel_decoded)
|
||||||
|
|
||||||
|
audio = torchaudio.functional.resample(audio, 16000, 44100)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def encode(self, audio):
|
||||||
|
audio = audio.mean(dim=1)
|
||||||
|
audio = torchaudio.functional.resample(audio, 44100, 16000)
|
||||||
|
dist = self.encode_audio(audio)
|
||||||
|
return dist.mean
|
||||||
219
comfy/ldm/mmaudio/vae/bigvgan.py
Normal file
219
comfy/ldm/mmaudio/vae/bigvgan.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
# Copyright (c) 2022 NVIDIA CORPORATION.
|
||||||
|
# Licensed under the MIT license.
|
||||||
|
|
||||||
|
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
|
||||||
|
# LICENSE is in incl_licenses directory.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from . import activations
|
||||||
|
from .alias_free_torch import Activation1d
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def get_padding(kernel_size, dilation=1):
|
||||||
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
class AMPBlock1(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5), activation=None):
|
||||||
|
super(AMPBlock1, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs1 = nn.ModuleList([
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0])),
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1])),
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2]))
|
||||||
|
])
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList([
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1)),
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1)),
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1))
|
||||||
|
])
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
|
||||||
|
|
||||||
|
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
acts1, acts2 = self.activations[::2], self.activations[1::2]
|
||||||
|
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
|
||||||
|
xt = a1(x)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = a2(xt)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = xt + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock2(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3), activation=None):
|
||||||
|
super(AMPBlock2, self).__init__()
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.convs = nn.ModuleList([
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0])),
|
||||||
|
ops.Conv1d(channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]))
|
||||||
|
])
|
||||||
|
|
||||||
|
self.num_layers = len(self.convs) # total number of conv layers
|
||||||
|
|
||||||
|
if activation == 'snake': # periodic nonlinearity with snake function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
elif activation == 'snakebeta': # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
|
self.activations = nn.ModuleList([
|
||||||
|
Activation1d(
|
||||||
|
activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||||
|
for _ in range(self.num_layers)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c, a in zip(self.convs, self.activations):
|
||||||
|
xt = a(x)
|
||||||
|
xt = c(xt)
|
||||||
|
x = xt + x
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BigVGANVocoder(torch.nn.Module):
|
||||||
|
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
|
||||||
|
def __init__(self, h):
|
||||||
|
super().__init__()
|
||||||
|
if isinstance(h, dict):
|
||||||
|
h = SimpleNamespace(**h)
|
||||||
|
self.h = h
|
||||||
|
|
||||||
|
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||||
|
self.num_upsamples = len(h.upsample_rates)
|
||||||
|
|
||||||
|
# pre conv
|
||||||
|
self.conv_pre = ops.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||||
|
|
||||||
|
# define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||||
|
resblock = AMPBlock1 if h.resblock == '1' else AMPBlock2
|
||||||
|
|
||||||
|
# transposed conv-based upsamplers. does not apply anti-aliasing
|
||||||
|
self.ups = nn.ModuleList()
|
||||||
|
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||||
|
self.ups.append(
|
||||||
|
nn.ModuleList([
|
||||||
|
ops.ConvTranspose1d(h.upsample_initial_channel // (2**i),
|
||||||
|
h.upsample_initial_channel // (2**(i + 1)),
|
||||||
|
k,
|
||||||
|
u,
|
||||||
|
padding=(k - u) // 2)
|
||||||
|
]))
|
||||||
|
|
||||||
|
# residual blocks using anti-aliased multi-periodicity composition modules (AMP)
|
||||||
|
self.resblocks = nn.ModuleList()
|
||||||
|
for i in range(len(self.ups)):
|
||||||
|
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||||
|
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||||
|
self.resblocks.append(resblock(h, ch, k, d, activation=h.activation))
|
||||||
|
|
||||||
|
# post conv
|
||||||
|
if h.activation == "snake": # periodic nonlinearity with snake function and anti-aliasing
|
||||||
|
activation_post = activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
|
elif h.activation == "snakebeta": # periodic nonlinearity with snakebeta function and anti-aliasing
|
||||||
|
activation_post = activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||||
|
self.activation_post = Activation1d(activation=activation_post)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"activation incorrectly specified. check the config file and look for 'activation'."
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv_post = ops.Conv1d(ch, 1, 7, 1, padding=3)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# pre conv
|
||||||
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
|
for i in range(self.num_upsamples):
|
||||||
|
# upsampling
|
||||||
|
for i_up in range(len(self.ups[i])):
|
||||||
|
x = self.ups[i][i_up](x)
|
||||||
|
# AMP blocks
|
||||||
|
xs = None
|
||||||
|
for j in range(self.num_kernels):
|
||||||
|
if xs is None:
|
||||||
|
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
else:
|
||||||
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
|
x = xs / self.num_kernels
|
||||||
|
|
||||||
|
# post conv
|
||||||
|
x = self.activation_post(x)
|
||||||
|
x = self.conv_post(x)
|
||||||
|
x = torch.tanh(x)
|
||||||
|
|
||||||
|
return x
|
||||||
92
comfy/ldm/mmaudio/vae/distributions.py
Normal file
92
comfy/ldm/mmaudio/vae/distributions.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractDistribution:
|
||||||
|
def sample(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class DiracDistribution(AbstractDistribution):
|
||||||
|
def __init__(self, value):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
class DiagonalGaussianDistribution(object):
|
||||||
|
def __init__(self, parameters, deterministic=False):
|
||||||
|
self.parameters = parameters
|
||||||
|
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||||
|
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||||
|
self.deterministic = deterministic
|
||||||
|
self.std = torch.exp(0.5 * self.logvar)
|
||||||
|
self.var = torch.exp(self.logvar)
|
||||||
|
if self.deterministic:
|
||||||
|
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||||
|
|
||||||
|
def sample(self):
|
||||||
|
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def kl(self, other=None):
|
||||||
|
if self.deterministic:
|
||||||
|
return torch.Tensor([0.])
|
||||||
|
else:
|
||||||
|
if other is None:
|
||||||
|
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||||
|
+ self.var - 1.0 - self.logvar,
|
||||||
|
dim=[1, 2, 3])
|
||||||
|
else:
|
||||||
|
return 0.5 * torch.sum(
|
||||||
|
torch.pow(self.mean - other.mean, 2) / other.var
|
||||||
|
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||||
|
dim=[1, 2, 3])
|
||||||
|
|
||||||
|
def nll(self, sample, dims=[1,2,3]):
|
||||||
|
if self.deterministic:
|
||||||
|
return torch.Tensor([0.])
|
||||||
|
logtwopi = np.log(2.0 * np.pi)
|
||||||
|
return 0.5 * torch.sum(
|
||||||
|
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||||
|
dim=dims)
|
||||||
|
|
||||||
|
def mode(self):
|
||||||
|
return self.mean
|
||||||
|
|
||||||
|
|
||||||
|
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||||
|
"""
|
||||||
|
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||||
|
Compute the KL divergence between two gaussians.
|
||||||
|
Shapes are automatically broadcasted, so batches can be compared to
|
||||||
|
scalars, among other use cases.
|
||||||
|
"""
|
||||||
|
tensor = None
|
||||||
|
for obj in (mean1, logvar1, mean2, logvar2):
|
||||||
|
if isinstance(obj, torch.Tensor):
|
||||||
|
tensor = obj
|
||||||
|
break
|
||||||
|
assert tensor is not None, "at least one argument must be a Tensor"
|
||||||
|
|
||||||
|
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||||
|
# Tensors, but it does not work for torch.exp().
|
||||||
|
logvar1, logvar2 = [
|
||||||
|
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||||
|
for x in (logvar1, logvar2)
|
||||||
|
]
|
||||||
|
|
||||||
|
return 0.5 * (
|
||||||
|
-1.0
|
||||||
|
+ logvar2
|
||||||
|
- logvar1
|
||||||
|
+ torch.exp(logvar1 - logvar2)
|
||||||
|
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||||
|
)
|
||||||
358
comfy/ldm/mmaudio/vae/vae.py
Normal file
358
comfy/ldm/mmaudio/vae/vae.py
Normal file
@ -0,0 +1,358 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
|
||||||
|
Upsample1D, nonlinearity)
|
||||||
|
from .distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
log = logging.getLogger()
|
||||||
|
|
||||||
|
DATA_MEAN_80D = [
|
||||||
|
-1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
|
||||||
|
-1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
|
||||||
|
-1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
|
||||||
|
-1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
|
||||||
|
-1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
|
||||||
|
-1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
|
||||||
|
-2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
|
||||||
|
-2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_STD_80D = [
|
||||||
|
1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
|
||||||
|
0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
|
||||||
|
0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
|
||||||
|
0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
|
||||||
|
0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
|
||||||
|
0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
|
||||||
|
1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_MEAN_128D = [
|
||||||
|
-3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
|
||||||
|
-2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
|
||||||
|
-2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
|
||||||
|
-3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
|
||||||
|
-3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
|
||||||
|
-3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
|
||||||
|
-3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
|
||||||
|
-4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
|
||||||
|
-4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
|
||||||
|
-5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
|
||||||
|
-6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
|
||||||
|
-7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
|
||||||
|
-9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
|
||||||
|
]
|
||||||
|
|
||||||
|
DATA_STD_128D = [
|
||||||
|
2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
|
||||||
|
2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
|
||||||
|
2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
|
||||||
|
2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
|
||||||
|
2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
|
||||||
|
2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
|
||||||
|
2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
|
||||||
|
2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
|
||||||
|
2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
|
||||||
|
2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
|
||||||
|
2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class VAE(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
data_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
hidden_dim: int,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if data_dim == 80:
|
||||||
|
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
|
||||||
|
self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
|
||||||
|
elif data_dim == 128:
|
||||||
|
self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
|
||||||
|
self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
|
||||||
|
|
||||||
|
self.data_mean = self.data_mean.view(1, -1, 1)
|
||||||
|
self.data_std = self.data_std.view(1, -1, 1)
|
||||||
|
|
||||||
|
self.encoder = Encoder1D(
|
||||||
|
dim=hidden_dim,
|
||||||
|
ch_mult=(1, 2, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_layers=[3],
|
||||||
|
down_layers=[0],
|
||||||
|
in_dim=data_dim,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
)
|
||||||
|
self.decoder = Decoder1D(
|
||||||
|
dim=hidden_dim,
|
||||||
|
ch_mult=(1, 2, 4),
|
||||||
|
num_res_blocks=2,
|
||||||
|
attn_layers=[3],
|
||||||
|
down_layers=[0],
|
||||||
|
in_dim=data_dim,
|
||||||
|
out_dim=data_dim,
|
||||||
|
embed_dim=embed_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
# self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
|
||||||
|
# self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
|
||||||
|
|
||||||
|
self.initialize_weights()
|
||||||
|
|
||||||
|
def initialize_weights(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
|
||||||
|
if normalize:
|
||||||
|
x = self.normalize(x)
|
||||||
|
moments = self.encoder(x)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
|
return posterior
|
||||||
|
|
||||||
|
def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
|
||||||
|
dec = self.decoder(z)
|
||||||
|
if unnormalize:
|
||||||
|
dec = self.unnormalize(dec)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return (x - comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)) / comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
return x * comfy.model_management.cast_to(self.data_std, dtype=x.dtype, device=x.device) + comfy.model_management.cast_to(self.data_mean, dtype=x.dtype, device=x.device)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
sample_posterior: bool = True,
|
||||||
|
rng: Optional[torch.Generator] = None,
|
||||||
|
normalize: bool = True,
|
||||||
|
unnormalize: bool = True,
|
||||||
|
) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
|
||||||
|
|
||||||
|
posterior = self.encode(x, normalize=normalize)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample(rng)
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
dec = self.decode(z, unnormalize=unnormalize)
|
||||||
|
return dec, posterior
|
||||||
|
|
||||||
|
def load_weights(self, src_dict) -> None:
|
||||||
|
self.load_state_dict(src_dict, strict=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self) -> torch.device:
|
||||||
|
return next(self.parameters()).device
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
def remove_weight_norm(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
dim: int,
|
||||||
|
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||||
|
num_res_blocks: int,
|
||||||
|
attn_layers: list[int] = [],
|
||||||
|
down_layers: list[int] = [],
|
||||||
|
resamp_with_conv: bool = True,
|
||||||
|
in_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
double_z: bool = True,
|
||||||
|
kernel_size: int = 3,
|
||||||
|
clip_act: float = 256.0):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_layers = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.in_channels = in_dim
|
||||||
|
self.clip_act = clip_act
|
||||||
|
self.down_layers = down_layers
|
||||||
|
self.attn_layers = attn_layers
|
||||||
|
self.conv_in = ops.Conv1d(in_dim, self.dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
|
||||||
|
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||||
|
self.in_ch_mult = in_ch_mult
|
||||||
|
# downsampling
|
||||||
|
self.down = nn.ModuleList()
|
||||||
|
for i_level in range(self.num_layers):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_in = dim * in_ch_mult[i_level]
|
||||||
|
block_out = dim * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
block.append(
|
||||||
|
ResnetBlock1D(in_dim=block_in,
|
||||||
|
out_dim=block_out,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_norm=True))
|
||||||
|
block_in = block_out
|
||||||
|
if i_level in attn_layers:
|
||||||
|
attn.append(AttnBlock1D(block_in))
|
||||||
|
down = nn.Module()
|
||||||
|
down.block = block
|
||||||
|
down.attn = attn
|
||||||
|
if i_level in down_layers:
|
||||||
|
down.downsample = Downsample1D(block_in, resamp_with_conv)
|
||||||
|
self.down.append(down)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
|
||||||
|
out_dim=block_in,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_norm=True)
|
||||||
|
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
|
||||||
|
out_dim=block_in,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
use_norm=True)
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.conv_out = ops.Conv1d(block_in,
|
||||||
|
2 * embed_dim if double_z else embed_dim,
|
||||||
|
kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
|
||||||
|
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
# downsampling
|
||||||
|
h = self.conv_in(x)
|
||||||
|
for i_level in range(self.num_layers):
|
||||||
|
for i_block in range(self.num_res_blocks):
|
||||||
|
h = self.down[i_level].block[i_block](h)
|
||||||
|
if len(self.down[i_level].attn) > 0:
|
||||||
|
h = self.down[i_level].attn[i_block](h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
if i_level in self.down_layers:
|
||||||
|
h = self.down[i_level].downsample(h)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
|
||||||
|
# end
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
*,
|
||||||
|
dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
ch_mult: tuple[int] = (1, 2, 4, 8),
|
||||||
|
num_res_blocks: int,
|
||||||
|
attn_layers: list[int] = [],
|
||||||
|
down_layers: list[int] = [],
|
||||||
|
kernel_size: int = 3,
|
||||||
|
resamp_with_conv: bool = True,
|
||||||
|
in_dim: int,
|
||||||
|
embed_dim: int,
|
||||||
|
clip_act: float = 256.0):
|
||||||
|
super().__init__()
|
||||||
|
self.ch = dim
|
||||||
|
self.num_layers = len(ch_mult)
|
||||||
|
self.num_res_blocks = num_res_blocks
|
||||||
|
self.in_channels = in_dim
|
||||||
|
self.clip_act = clip_act
|
||||||
|
self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
|
||||||
|
|
||||||
|
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||||
|
block_in = dim * ch_mult[self.num_layers - 1]
|
||||||
|
|
||||||
|
# z to block_in
|
||||||
|
self.conv_in = ops.Conv1d(embed_dim, block_in, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
self.mid = nn.Module()
|
||||||
|
self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||||
|
self.mid.attn_1 = AttnBlock1D(block_in)
|
||||||
|
self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
self.up = nn.ModuleList()
|
||||||
|
for i_level in reversed(range(self.num_layers)):
|
||||||
|
block = nn.ModuleList()
|
||||||
|
attn = nn.ModuleList()
|
||||||
|
block_out = dim * ch_mult[i_level]
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
|
||||||
|
block_in = block_out
|
||||||
|
if i_level in attn_layers:
|
||||||
|
attn.append(AttnBlock1D(block_in))
|
||||||
|
up = nn.Module()
|
||||||
|
up.block = block
|
||||||
|
up.attn = attn
|
||||||
|
if i_level in self.down_layers:
|
||||||
|
up.upsample = Upsample1D(block_in, resamp_with_conv)
|
||||||
|
self.up.insert(0, up) # prepend to get consistent order
|
||||||
|
|
||||||
|
# end
|
||||||
|
self.conv_out = ops.Conv1d(block_in, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
self.learnable_gain = nn.Parameter(torch.zeros([]))
|
||||||
|
|
||||||
|
def forward(self, z):
|
||||||
|
# z to block_in
|
||||||
|
h = self.conv_in(z)
|
||||||
|
|
||||||
|
# middle
|
||||||
|
h = self.mid.block_1(h)
|
||||||
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
|
||||||
|
# upsampling
|
||||||
|
for i_level in reversed(range(self.num_layers)):
|
||||||
|
for i_block in range(self.num_res_blocks + 1):
|
||||||
|
h = self.up[i_level].block[i_block](h)
|
||||||
|
if len(self.up[i_level].attn) > 0:
|
||||||
|
h = self.up[i_level].attn[i_block](h)
|
||||||
|
h = h.clamp(-self.clip_act, self.clip_act)
|
||||||
|
if i_level in self.down_layers:
|
||||||
|
h = self.up[i_level].upsample(h)
|
||||||
|
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h) * (self.learnable_gain + 1)
|
||||||
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
def VAE_16k(**kwargs) -> VAE:
|
||||||
|
return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def VAE_44k(**kwargs) -> VAE:
|
||||||
|
return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_my_vae(name: str, **kwargs) -> VAE:
|
||||||
|
if name == '16k':
|
||||||
|
return VAE_16k(**kwargs)
|
||||||
|
if name == '44k':
|
||||||
|
return VAE_44k(**kwargs)
|
||||||
|
raise ValueError(f'Unknown model: {name}')
|
||||||
|
|
||||||
121
comfy/ldm/mmaudio/vae/vae_modules.py
Normal file
121
comfy/ldm/mmaudio/vae/vae_modules.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||||
|
import math
|
||||||
|
import comfy.ops
|
||||||
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def nonlinearity(x):
|
||||||
|
# swish
|
||||||
|
return torch.nn.functional.silu(x) / 0.596
|
||||||
|
|
||||||
|
def mp_sum(a, b, t=0.5):
|
||||||
|
return a.lerp(b, t) / math.sqrt((1 - t)**2 + t**2)
|
||||||
|
|
||||||
|
def normalize(x, dim=None, eps=1e-4):
|
||||||
|
if dim is None:
|
||||||
|
dim = list(range(1, x.ndim))
|
||||||
|
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
|
||||||
|
norm = torch.add(eps, norm, alpha=math.sqrt(norm.numel() / x.numel()))
|
||||||
|
return x / norm.to(x.dtype)
|
||||||
|
|
||||||
|
class ResnetBlock1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
|
||||||
|
super().__init__()
|
||||||
|
self.in_dim = in_dim
|
||||||
|
out_dim = in_dim if out_dim is None else out_dim
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.use_conv_shortcut = conv_shortcut
|
||||||
|
self.use_norm = use_norm
|
||||||
|
|
||||||
|
self.conv1 = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
self.conv2 = ops.Conv1d(out_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
if self.in_dim != self.out_dim:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
self.conv_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
|
||||||
|
else:
|
||||||
|
self.nin_shortcut = ops.Conv1d(in_dim, out_dim, kernel_size=1, padding=0, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
|
||||||
|
# pixel norm
|
||||||
|
if self.use_norm:
|
||||||
|
x = normalize(x, dim=1)
|
||||||
|
|
||||||
|
h = x
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv1(h)
|
||||||
|
|
||||||
|
h = nonlinearity(h)
|
||||||
|
h = self.conv2(h)
|
||||||
|
|
||||||
|
if self.in_dim != self.out_dim:
|
||||||
|
if self.use_conv_shortcut:
|
||||||
|
x = self.conv_shortcut(x)
|
||||||
|
else:
|
||||||
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
|
return mp_sum(x, h, t=0.3)
|
||||||
|
|
||||||
|
|
||||||
|
class AttnBlock1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, num_heads=1):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.qkv = ops.Conv1d(in_channels, in_channels * 3, kernel_size=1, padding=0, bias=False)
|
||||||
|
self.proj_out = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||||
|
self.optimized_attention = vae_attention()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
h = x
|
||||||
|
y = self.qkv(h)
|
||||||
|
y = y.reshape(y.shape[0], -1, 3, y.shape[-1])
|
||||||
|
q, k, v = normalize(y, dim=1).unbind(2)
|
||||||
|
|
||||||
|
h = self.optimized_attention(q, k, v)
|
||||||
|
h = self.proj_out(h)
|
||||||
|
|
||||||
|
return mp_sum(x, h, t=0.3)
|
||||||
|
|
||||||
|
|
||||||
|
class Upsample1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
self.conv = ops.Conv1d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Downsample1D(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_channels, with_conv):
|
||||||
|
super().__init__()
|
||||||
|
self.with_conv = with_conv
|
||||||
|
if self.with_conv:
|
||||||
|
# no asymmetric padding in torch conv, must do it ourselves
|
||||||
|
self.conv1 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||||
|
self.conv2 = ops.Conv1d(in_channels, in_channels, kernel_size=1, padding=0, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv1(x)
|
||||||
|
|
||||||
|
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||||
|
|
||||||
|
if self.with_conv:
|
||||||
|
x = self.conv2(x)
|
||||||
|
|
||||||
|
return x
|
||||||
@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
|||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
del ids, txt_ids, img_ids
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from comfy.ldm.modules.attention import optimized_attention_masked
|
|||||||
from comfy.ldm.flux.layers import EmbedND
|
from comfy.ldm.flux.layers import EmbedND
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
import comfy.patcher_extension
|
import comfy.patcher_extension
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
|
||||||
class GELU(nn.Module):
|
class GELU(nn.Module):
|
||||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
|
||||||
@ -134,33 +135,34 @@ class Attention(nn.Module):
|
|||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
batch_size = hidden_states.shape[0]
|
||||||
|
seq_img = hidden_states.shape[1]
|
||||||
seq_txt = encoder_hidden_states.shape[1]
|
seq_txt = encoder_hidden_states.shape[1]
|
||||||
|
|
||||||
img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
|
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||||
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
|
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
|
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
|
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
|
||||||
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
|
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)
|
||||||
|
|
||||||
img_query = self.norm_q(img_query)
|
img_query = self.norm_q(img_query)
|
||||||
img_key = self.norm_k(img_key)
|
img_key = self.norm_k(img_key)
|
||||||
txt_query = self.norm_added_q(txt_query)
|
txt_query = self.norm_added_q(txt_query)
|
||||||
txt_key = self.norm_added_k(txt_key)
|
txt_key = self.norm_added_k(txt_key)
|
||||||
|
|
||||||
joint_query = torch.cat([txt_query, img_query], dim=1)
|
joint_query = torch.cat([txt_query, img_query], dim=2)
|
||||||
joint_key = torch.cat([txt_key, img_key], dim=1)
|
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||||
joint_value = torch.cat([txt_value, img_value], dim=1)
|
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||||
|
|
||||||
joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
|
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||||
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
|
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||||
|
|
||||||
joint_query = joint_query.flatten(start_dim=2)
|
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||||
joint_key = joint_key.flatten(start_dim=2)
|
attention_mask, transformer_options=transformer_options,
|
||||||
joint_value = joint_value.flatten(start_dim=2)
|
skip_reshape=True)
|
||||||
|
|
||||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
|
|
||||||
|
|
||||||
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
txt_attn_output = joint_hidden_states[:, :seq_txt, :]
|
||||||
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
img_attn_output = joint_hidden_states[:, seq_txt:, :]
|
||||||
@ -413,7 +415,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
|
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
del ids, txt_ids, img_ids
|
del ids, txt_ids, img_ids
|
||||||
|
|
||||||
hidden_states = self.img_in(hidden_states)
|
hidden_states = self.img_in(hidden_states)
|
||||||
|
|||||||
@ -232,6 +232,7 @@ class WanAttentionBlock(nn.Module):
|
|||||||
# assert e[0].dtype == torch.float32
|
# assert e[0].dtype == torch.float32
|
||||||
|
|
||||||
# self-attention
|
# self-attention
|
||||||
|
x = x.contiguous() # otherwise implicit in LayerNorm
|
||||||
y = self.self_attn(
|
y = self.self_attn(
|
||||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
freqs, transformer_options=transformer_options)
|
freqs, transformer_options=transformer_options)
|
||||||
@ -588,7 +589,7 @@ class WanModel(torch.nn.Module):
|
|||||||
x = self.unpatchify(x, grid_sizes)
|
x = self.unpatchify(x, grid_sizes)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
|
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||||
@ -601,10 +602,22 @@ class WanModel(torch.nn.Module):
|
|||||||
if steps_w is None:
|
if steps_w is None:
|
||||||
steps_w = w_len
|
steps_w = w_len
|
||||||
|
|
||||||
|
h_start = 0
|
||||||
|
w_start = 0
|
||||||
|
rope_options = transformer_options.get("rope_options", None)
|
||||||
|
if rope_options is not None:
|
||||||
|
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||||
|
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||||
|
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||||
|
|
||||||
|
t_start += rope_options.get("shift_t", 0.0)
|
||||||
|
h_start += rope_options.get("shift_y", 0.0)
|
||||||
|
w_start += rope_options.get("shift_x", 0.0)
|
||||||
|
|
||||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||||
|
|
||||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||||
@ -630,7 +643,7 @@ class WanModel(torch.nn.Module):
|
|||||||
if self.ref_conv is not None and "reference_latent" in kwargs:
|
if self.ref_conv is not None and "reference_latent" in kwargs:
|
||||||
t_len += 1
|
t_len += 1
|
||||||
|
|
||||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||||
|
|
||||||
def unpatchify(self, x, grid_sizes):
|
def unpatchify(self, x, grid_sizes):
|
||||||
|
|||||||
@ -657,51 +657,51 @@ class WanVAE(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x):
|
||||||
self.clear_cache()
|
conv_idx = [0]
|
||||||
|
feat_map = [None] * count_conv3d(self.encoder)
|
||||||
x = patchify(x, patch_size=2)
|
x = patchify(x, patch_size=2)
|
||||||
t = x.shape[2]
|
t = x.shape[2]
|
||||||
iter_ = 1 + (t - 1) // 4
|
iter_ = 1 + (t - 1) // 4
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._enc_conv_idx = [0]
|
conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.encoder(
|
out = self.encoder(
|
||||||
x[:, :, :1, :, :],
|
x[:, :, :1, :, :],
|
||||||
feat_cache=self._enc_feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=self._enc_conv_idx,
|
feat_idx=conv_idx,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_ = self.encoder(
|
out_ = self.encoder(
|
||||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||||
feat_cache=self._enc_feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=self._enc_conv_idx,
|
feat_idx=conv_idx,
|
||||||
)
|
)
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||||
self.clear_cache()
|
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
def decode(self, z):
|
def decode(self, z):
|
||||||
self.clear_cache()
|
conv_idx = [0]
|
||||||
|
feat_map = [None] * count_conv3d(self.decoder)
|
||||||
iter_ = z.shape[2]
|
iter_ = z.shape[2]
|
||||||
x = self.conv2(z)
|
x = self.conv2(z)
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._conv_idx = [0]
|
conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.decoder(
|
out = self.decoder(
|
||||||
x[:, :, i:i + 1, :, :],
|
x[:, :, i:i + 1, :, :],
|
||||||
feat_cache=self._feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=self._conv_idx,
|
feat_idx=conv_idx,
|
||||||
first_chunk=True,
|
first_chunk=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(
|
out_ = self.decoder(
|
||||||
x[:, :, i:i + 1, :, :],
|
x[:, :, i:i + 1, :, :],
|
||||||
feat_cache=self._feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=self._conv_idx,
|
feat_idx=conv_idx,
|
||||||
)
|
)
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
out = unpatchify(out, patch_size=2)
|
out = unpatchify(out, patch_size=2)
|
||||||
self.clear_cache()
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def reparameterize(self, mu, log_var):
|
def reparameterize(self, mu, log_var):
|
||||||
@ -715,12 +715,3 @@ class WanVAE(nn.Module):
|
|||||||
return mu
|
return mu
|
||||||
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
|
||||||
return mu + std * torch.randn_like(std)
|
return mu + std * torch.randn_like(std)
|
||||||
|
|
||||||
def clear_cache(self):
|
|
||||||
self._conv_num = count_conv3d(self.decoder)
|
|
||||||
self._conv_idx = [0]
|
|
||||||
self._feat_map = [None] * self._conv_num
|
|
||||||
# cache encode
|
|
||||||
self._enc_conv_num = count_conv3d(self.encoder)
|
|
||||||
self._enc_conv_idx = [0]
|
|
||||||
self._enc_feat_map = [None] * self._enc_conv_num
|
|
||||||
|
|||||||
@ -134,10 +134,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
if not unet_config.get("disable_unet_model_creation", False):
|
if not unet_config.get("disable_unet_model_creation", False):
|
||||||
if model_config.custom_operations is None:
|
if model_config.custom_operations is None:
|
||||||
fp8 = model_config.optimizations.get("fp8", False)
|
fp8 = model_config.optimizations.get("fp8", False)
|
||||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
|
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||||
else:
|
else:
|
||||||
operations = model_config.custom_operations
|
operations = model_config.custom_operations
|
||||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||||
|
self.diffusion_model.eval()
|
||||||
if comfy.model_management.force_channels_last():
|
if comfy.model_management.force_channels_last():
|
||||||
self.diffusion_model.to(memory_format=torch.channels_last)
|
self.diffusion_model.to(memory_format=torch.channels_last)
|
||||||
logging.debug("using channels last mode for diffusion model")
|
logging.debug("using channels last mode for diffusion model")
|
||||||
@ -196,8 +197,14 @@ class BaseModel(torch.nn.Module):
|
|||||||
extra_conds[o] = extra
|
extra_conds[o] = extra
|
||||||
|
|
||||||
t = self.process_timestep(t, x=x, **extra_conds)
|
t = self.process_timestep(t, x=x, **extra_conds)
|
||||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
if "latent_shapes" in extra_conds:
|
||||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||||
|
|
||||||
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||||
|
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||||
|
model_output, _ = utils.pack_latents(model_output)
|
||||||
|
|
||||||
|
return self.model_sampling.calculate_denoised(sigma, model_output.float(), x)
|
||||||
|
|
||||||
def process_timestep(self, timestep, **kwargs):
|
def process_timestep(self, timestep, **kwargs):
|
||||||
return timestep
|
return timestep
|
||||||
@ -326,6 +333,14 @@ class BaseModel(torch.nn.Module):
|
|||||||
if self.model_config.scaled_fp8 is not None:
|
if self.model_config.scaled_fp8 is not None:
|
||||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||||
|
|
||||||
|
# Save mixed precision metadata
|
||||||
|
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||||
|
metadata = {
|
||||||
|
"format_version": "1.0",
|
||||||
|
"layers": self.model_config.layer_quant_config
|
||||||
|
}
|
||||||
|
unet_state_dict["_quantization_metadata"] = metadata
|
||||||
|
|
||||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||||
|
|
||||||
if self.model_type == ModelType.V_PREDICTION:
|
if self.model_type == ModelType.V_PREDICTION:
|
||||||
@ -669,7 +684,6 @@ class Lotus(BaseModel):
|
|||||||
class StableCascade_C(BaseModel):
|
class StableCascade_C(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
super().__init__(model_config, model_type, device=device, unet_model=StageC)
|
||||||
self.diffusion_model.eval().requires_grad_(False)
|
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
@ -698,7 +712,6 @@ class StableCascade_C(BaseModel):
|
|||||||
class StableCascade_B(BaseModel):
|
class StableCascade_B(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
super().__init__(model_config, model_type, device=device, unet_model=StageB)
|
||||||
self.diffusion_model.eval().requires_grad_(False)
|
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
|
|||||||
@ -6,6 +6,20 @@ import math
|
|||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def detect_layer_quantization(metadata):
|
||||||
|
quant_key = "_quantization_metadata"
|
||||||
|
if metadata is not None and quant_key in metadata:
|
||||||
|
quant_metadata = metadata.pop(quant_key)
|
||||||
|
quant_metadata = json.loads(quant_metadata)
|
||||||
|
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||||
|
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||||
|
return quant_metadata["layers"]
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid quantization metadata format")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def count_blocks(state_dict_keys, prefix_string):
|
def count_blocks(state_dict_keys, prefix_string):
|
||||||
count = 0
|
count = 0
|
||||||
while True:
|
while True:
|
||||||
@ -213,7 +227,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["nerf_mlp_ratio"] = 4
|
dit_config["nerf_mlp_ratio"] = 4
|
||||||
dit_config["nerf_depth"] = 4
|
dit_config["nerf_depth"] = 4
|
||||||
dit_config["nerf_max_freqs"] = 8
|
dit_config["nerf_max_freqs"] = 8
|
||||||
dit_config["nerf_tile_size"] = 32
|
dit_config["nerf_tile_size"] = 512
|
||||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||||
else:
|
else:
|
||||||
@ -701,6 +715,12 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
|||||||
else:
|
else:
|
||||||
model_config.optimizations["fp8"] = True
|
model_config.optimizations["fp8"] = True
|
||||||
|
|
||||||
|
# Detect per-layer quantization (mixed precision)
|
||||||
|
layer_quant_config = detect_layer_quantization(metadata)
|
||||||
|
if layer_quant_config:
|
||||||
|
model_config.layer_quant_config = layer_quant_config
|
||||||
|
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||||
|
|
||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
def unet_prefix_from_state_dict(state_dict):
|
def unet_prefix_from_state_dict(state_dict):
|
||||||
|
|||||||
@ -89,6 +89,7 @@ if args.deterministic:
|
|||||||
|
|
||||||
directml_enabled = False
|
directml_enabled = False
|
||||||
if args.directml is not None:
|
if args.directml is not None:
|
||||||
|
logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.")
|
||||||
import torch_directml
|
import torch_directml
|
||||||
directml_enabled = True
|
directml_enabled = True
|
||||||
device_index = args.directml
|
device_index = args.directml
|
||||||
@ -330,13 +331,21 @@ except:
|
|||||||
|
|
||||||
|
|
||||||
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
SUPPORT_FP8_OPS = args.supports_fp8_compute
|
||||||
|
|
||||||
|
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_amd():
|
if is_amd():
|
||||||
|
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
||||||
|
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
|
||||||
|
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
|
||||||
|
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
|
||||||
except:
|
except:
|
||||||
rocm_version = (6, -1)
|
rocm_version = (6, -1)
|
||||||
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
|
|
||||||
logging.info("AMD arch: {}".format(arch))
|
logging.info("AMD arch: {}".format(arch))
|
||||||
logging.info("ROCm version: {}".format(rocm_version))
|
logging.info("ROCm version: {}".format(rocm_version))
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
@ -344,11 +353,11 @@ try:
|
|||||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
# if torch_version_numeric >= (2, 8):
|
if rocm_version >= (7, 0):
|
||||||
# if any((a in arch) for a in ["gfx1201"]):
|
if any((a in arch) for a in ["gfx1201"]):
|
||||||
# ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx942", "gfx950"]): # TODO: more arches
|
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||||
SUPPORT_FP8_OPS = True
|
SUPPORT_FP8_OPS = True
|
||||||
|
|
||||||
except:
|
except:
|
||||||
@ -370,6 +379,9 @@ try:
|
|||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch_version_numeric >= (2, 5):
|
if torch_version_numeric >= (2, 5):
|
||||||
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
|
||||||
@ -925,11 +937,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
|||||||
if d == torch.float16 and should_use_fp16(device):
|
if d == torch.float16 and should_use_fp16(device):
|
||||||
return d
|
return d
|
||||||
|
|
||||||
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
if d == torch.bfloat16 and should_use_bf16(device):
|
||||||
# slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
|
|
||||||
# also a problem on RDNA4 except fp32 is also slow there.
|
|
||||||
# This is due to large bf16 convolutions being extremely slow.
|
|
||||||
if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return torch.float32
|
return torch.float32
|
||||||
@ -991,12 +999,6 @@ def device_supports_non_blocking(device):
|
|||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def device_should_use_non_blocking(device):
|
|
||||||
if not device_supports_non_blocking(device):
|
|
||||||
return False
|
|
||||||
return False
|
|
||||||
# return True #TODO: figure out why this causes memory issues on Nvidia and possibly others
|
|
||||||
|
|
||||||
def force_channels_last():
|
def force_channels_last():
|
||||||
if args.force_channels_last:
|
if args.force_channels_last:
|
||||||
return True
|
return True
|
||||||
@ -1011,6 +1013,16 @@ if args.async_offload:
|
|||||||
NUM_STREAMS = 2
|
NUM_STREAMS = 2
|
||||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||||
|
|
||||||
|
def current_stream(device):
|
||||||
|
if device is None:
|
||||||
|
return None
|
||||||
|
if is_device_cuda(device):
|
||||||
|
return torch.cuda.current_stream()
|
||||||
|
elif is_device_xpu(device):
|
||||||
|
return torch.xpu.current_stream()
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
stream_counters = {}
|
stream_counters = {}
|
||||||
def get_offload_stream(device):
|
def get_offload_stream(device):
|
||||||
stream_counter = stream_counters.get(device, 0)
|
stream_counter = stream_counters.get(device, 0)
|
||||||
@ -1019,21 +1031,17 @@ def get_offload_stream(device):
|
|||||||
|
|
||||||
if device in STREAMS:
|
if device in STREAMS:
|
||||||
ss = STREAMS[device]
|
ss = STREAMS[device]
|
||||||
s = ss[stream_counter]
|
#Sync the oldest stream in the queue with the current
|
||||||
|
ss[stream_counter].wait_stream(current_stream(device))
|
||||||
stream_counter = (stream_counter + 1) % len(ss)
|
stream_counter = (stream_counter + 1) % len(ss)
|
||||||
if is_device_cuda(device):
|
|
||||||
ss[stream_counter].wait_stream(torch.cuda.current_stream())
|
|
||||||
elif is_device_xpu(device):
|
|
||||||
ss[stream_counter].wait_stream(torch.xpu.current_stream())
|
|
||||||
stream_counters[device] = stream_counter
|
stream_counters[device] = stream_counter
|
||||||
return s
|
return ss[stream_counter]
|
||||||
elif is_device_cuda(device):
|
elif is_device_cuda(device):
|
||||||
ss = []
|
ss = []
|
||||||
for k in range(NUM_STREAMS):
|
for k in range(NUM_STREAMS):
|
||||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||||
STREAMS[device] = ss
|
STREAMS[device] = ss
|
||||||
s = ss[stream_counter]
|
s = ss[stream_counter]
|
||||||
stream_counter = (stream_counter + 1) % len(ss)
|
|
||||||
stream_counters[device] = stream_counter
|
stream_counters[device] = stream_counter
|
||||||
return s
|
return s
|
||||||
elif is_device_xpu(device):
|
elif is_device_xpu(device):
|
||||||
@ -1042,18 +1050,14 @@ def get_offload_stream(device):
|
|||||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||||
STREAMS[device] = ss
|
STREAMS[device] = ss
|
||||||
s = ss[stream_counter]
|
s = ss[stream_counter]
|
||||||
stream_counter = (stream_counter + 1) % len(ss)
|
|
||||||
stream_counters[device] = stream_counter
|
stream_counters[device] = stream_counter
|
||||||
return s
|
return s
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def sync_stream(device, stream):
|
def sync_stream(device, stream):
|
||||||
if stream is None:
|
if stream is None or current_stream(device) is None:
|
||||||
return
|
return
|
||||||
if is_device_cuda(device):
|
current_stream(device).wait_stream(stream)
|
||||||
torch.cuda.current_stream().wait_stream(stream)
|
|
||||||
elif is_device_xpu(device):
|
|
||||||
torch.xpu.current_stream().wait_stream(stream)
|
|
||||||
|
|
||||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
|
||||||
if device is None or weight.device == device:
|
if device is None or weight.device == device:
|
||||||
@ -1078,6 +1082,73 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
|||||||
non_blocking = device_supports_non_blocking(device)
|
non_blocking = device_supports_non_blocking(device)
|
||||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
|
PINNED_MEMORY = {}
|
||||||
|
TOTAL_PINNED_MEMORY = 0
|
||||||
|
MAX_PINNED_MEMORY = -1
|
||||||
|
if not args.disable_pinned_memory:
|
||||||
|
if is_nvidia() or is_amd():
|
||||||
|
if WINDOWS:
|
||||||
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50%
|
||||||
|
else:
|
||||||
|
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||||
|
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||||
|
|
||||||
|
|
||||||
|
def pin_memory(tensor):
|
||||||
|
global TOTAL_PINNED_MEMORY
|
||||||
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not is_device_cpu(tensor.device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if tensor.is_pinned():
|
||||||
|
#NOTE: Cuda does detect when a tensor is already pinned and would
|
||||||
|
#error below, but there are proven cases where this also queues an error
|
||||||
|
#on the GPU async. So dont trust the CUDA API and guard here
|
||||||
|
return False
|
||||||
|
|
||||||
|
size = tensor.numel() * tensor.element_size()
|
||||||
|
if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY:
|
||||||
|
return False
|
||||||
|
|
||||||
|
ptr = tensor.data_ptr()
|
||||||
|
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||||
|
PINNED_MEMORY[ptr] = size
|
||||||
|
TOTAL_PINNED_MEMORY += size
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def unpin_memory(tensor):
|
||||||
|
global TOTAL_PINNED_MEMORY
|
||||||
|
if MAX_PINNED_MEMORY <= 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not is_device_cpu(tensor.device):
|
||||||
|
return False
|
||||||
|
|
||||||
|
ptr = tensor.data_ptr()
|
||||||
|
size = tensor.numel() * tensor.element_size()
|
||||||
|
|
||||||
|
size_stored = PINNED_MEMORY.get(ptr, None)
|
||||||
|
if size_stored is None:
|
||||||
|
logging.warning("Tried to unpin tensor not pinned by ComfyUI")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if size != size_stored:
|
||||||
|
logging.warning("Size of pinned tensor changed")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
|
||||||
|
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
|
||||||
|
if len(PINNED_MEMORY) == 0:
|
||||||
|
TOTAL_PINNED_MEMORY = 0
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
def sage_attention_enabled():
|
def sage_attention_enabled():
|
||||||
return args.use_sage_attention
|
return args.use_sage_attention
|
||||||
|
|
||||||
@ -1330,7 +1401,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
|
|
||||||
if is_amd():
|
if is_amd():
|
||||||
arch = torch.cuda.get_device_properties(device).gcnArchName
|
arch = torch.cuda.get_device_properties(device).gcnArchName
|
||||||
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
|
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
|
||||||
if manual_cast:
|
if manual_cast:
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|||||||
@ -123,16 +123,30 @@ def move_weight_functions(m, device):
|
|||||||
return memory
|
return memory
|
||||||
|
|
||||||
class LowVramPatch:
|
class LowVramPatch:
|
||||||
def __init__(self, key, patches):
|
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.patches = patches
|
self.patches = patches
|
||||||
|
self.convert_func = convert_func
|
||||||
|
self.set_func = set_func
|
||||||
|
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
intermediate_dtype = weight.dtype
|
intermediate_dtype = weight.dtype
|
||||||
|
if self.convert_func is not None:
|
||||||
|
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
|
||||||
|
|
||||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||||
intermediate_dtype = torch.float32
|
intermediate_dtype = torch.float32
|
||||||
return comfy.float.stochastic_rounding(comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype), weight.dtype, seed=string_to_seed(self.key))
|
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
if self.set_func is None:
|
||||||
|
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||||
|
else:
|
||||||
|
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||||
|
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||||
|
if self.set_func is not None:
|
||||||
|
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||||
|
else:
|
||||||
|
return out
|
||||||
|
|
||||||
def get_key_weight(model, key):
|
def get_key_weight(model, key):
|
||||||
set_func = None
|
set_func = None
|
||||||
@ -224,6 +238,7 @@ class ModelPatcher:
|
|||||||
self.force_cast_weights = False
|
self.force_cast_weights = False
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
self.parent = None
|
self.parent = None
|
||||||
|
self.pinned = set()
|
||||||
|
|
||||||
self.attachments: dict[str] = {}
|
self.attachments: dict[str] = {}
|
||||||
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
self.additional_models: dict[str, list[ModelPatcher]] = {}
|
||||||
@ -261,6 +276,9 @@ class ModelPatcher:
|
|||||||
self.size = comfy.model_management.module_size(self.model)
|
self.size = comfy.model_management.module_size(self.model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
def get_ram_usage(self):
|
||||||
|
return self.model_size()
|
||||||
|
|
||||||
def loaded_size(self):
|
def loaded_size(self):
|
||||||
return self.model.model_loaded_weight_memory
|
return self.model.model_loaded_weight_memory
|
||||||
|
|
||||||
@ -280,6 +298,7 @@ class ModelPatcher:
|
|||||||
n.backup = self.backup
|
n.backup = self.backup
|
||||||
n.object_patches_backup = self.object_patches_backup
|
n.object_patches_backup = self.object_patches_backup
|
||||||
n.parent = self
|
n.parent = self
|
||||||
|
n.pinned = self.pinned
|
||||||
|
|
||||||
n.force_cast_weights = self.force_cast_weights
|
n.force_cast_weights = self.force_cast_weights
|
||||||
|
|
||||||
@ -436,6 +455,19 @@ class ModelPatcher:
|
|||||||
def set_model_post_input_patch(self, patch):
|
def set_model_post_input_patch(self, patch):
|
||||||
self.set_model_patch(patch, "post_input")
|
self.set_model_patch(patch, "post_input")
|
||||||
|
|
||||||
|
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||||
|
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||||
|
rope_options["scale_x"] = scale_x
|
||||||
|
rope_options["scale_y"] = scale_y
|
||||||
|
rope_options["scale_t"] = scale_t
|
||||||
|
|
||||||
|
rope_options["shift_x"] = shift_x
|
||||||
|
rope_options["shift_y"] = shift_y
|
||||||
|
rope_options["shift_t"] = shift_t
|
||||||
|
|
||||||
|
self.model_options["transformer_options"]["rope_options"] = rope_options
|
||||||
|
|
||||||
|
|
||||||
def add_object_patch(self, name, obj):
|
def add_object_patch(self, name, obj):
|
||||||
self.object_patches[name] = obj
|
self.object_patches[name] = obj
|
||||||
|
|
||||||
@ -604,6 +636,21 @@ class ModelPatcher:
|
|||||||
else:
|
else:
|
||||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||||
|
|
||||||
|
def pin_weight_to_device(self, key):
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
if comfy.model_management.pin_memory(weight):
|
||||||
|
self.pinned.add(key)
|
||||||
|
|
||||||
|
def unpin_weight(self, key):
|
||||||
|
if key in self.pinned:
|
||||||
|
weight, set_func, convert_func = get_key_weight(self.model, key)
|
||||||
|
comfy.model_management.unpin_memory(weight)
|
||||||
|
self.pinned.remove(key)
|
||||||
|
|
||||||
|
def unpin_all_weights(self):
|
||||||
|
for key in list(self.pinned):
|
||||||
|
self.unpin_weight(key)
|
||||||
|
|
||||||
def _load_list(self):
|
def _load_list(self):
|
||||||
loading = []
|
loading = []
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
@ -625,9 +672,11 @@ class ModelPatcher:
|
|||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
|
lowvram_mem_counter = 0
|
||||||
loading = self._load_list()
|
loading = self._load_list()
|
||||||
|
|
||||||
load_completely = []
|
load_completely = []
|
||||||
|
offloaded = []
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
n = x[1]
|
n = x[1]
|
||||||
@ -644,6 +693,7 @@ class ModelPatcher:
|
|||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
lowvram_counter += 1
|
lowvram_counter += 1
|
||||||
|
lowvram_mem_counter += module_mem
|
||||||
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -657,16 +707,19 @@ class ModelPatcher:
|
|||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(weight_key)
|
self.patch_weight_to_device(weight_key)
|
||||||
else:
|
else:
|
||||||
m.weight_function = [LowVramPatch(weight_key, self.patches)]
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function = [LowVramPatch(weight_key, self.patches, convert_func, set_func)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
if force_patch_weights:
|
if force_patch_weights:
|
||||||
self.patch_weight_to_device(bias_key)
|
self.patch_weight_to_device(bias_key)
|
||||||
else:
|
else:
|
||||||
m.bias_function = [LowVramPatch(bias_key, self.patches)]
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function = [LowVramPatch(bias_key, self.patches, convert_func, set_func)]
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
|
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
|
offloaded.append((module_mem, n, m, params))
|
||||||
else:
|
else:
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
wipe_lowvram_weight(m)
|
wipe_lowvram_weight(m)
|
||||||
@ -697,7 +750,9 @@ class ModelPatcher:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
for param in params:
|
for param in params:
|
||||||
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
|
key = "{}.{}".format(n, param)
|
||||||
|
self.unpin_weight(key)
|
||||||
|
self.patch_weight_to_device(key, device_to=device_to)
|
||||||
|
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
m.comfy_patched_weights = True
|
m.comfy_patched_weights = True
|
||||||
@ -705,11 +760,17 @@ class ModelPatcher:
|
|||||||
for x in load_completely:
|
for x in load_completely:
|
||||||
x[2].to(device_to)
|
x[2].to(device_to)
|
||||||
|
|
||||||
|
for x in offloaded:
|
||||||
|
n = x[1]
|
||||||
|
params = x[3]
|
||||||
|
for param in params:
|
||||||
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||||
|
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||||
self.model.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
if full_load:
|
if full_load:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
@ -746,6 +807,7 @@ class ModelPatcher:
|
|||||||
self.eject_model()
|
self.eject_model()
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
self.unpatch_hooks()
|
self.unpatch_hooks()
|
||||||
|
self.unpin_all_weights()
|
||||||
if self.model.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
move_weight_functions(m, device_to)
|
move_weight_functions(m, device_to)
|
||||||
@ -825,10 +887,12 @@ class ModelPatcher:
|
|||||||
module_mem += move_weight_functions(m, device_to)
|
module_mem += move_weight_functions(m, device_to)
|
||||||
if lowvram_possible:
|
if lowvram_possible:
|
||||||
if weight_key in self.patches:
|
if weight_key in self.patches:
|
||||||
m.weight_function.append(LowVramPatch(weight_key, self.patches))
|
_, set_func, convert_func = get_key_weight(self.model, weight_key)
|
||||||
|
m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
if bias_key in self.patches:
|
if bias_key in self.patches:
|
||||||
m.bias_function.append(LowVramPatch(bias_key, self.patches))
|
_, set_func, convert_func = get_key_weight(self.model, bias_key)
|
||||||
|
m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func))
|
||||||
patch_counter += 1
|
patch_counter += 1
|
||||||
cast_weight = True
|
cast_weight = True
|
||||||
|
|
||||||
@ -839,6 +903,9 @@ class ModelPatcher:
|
|||||||
memory_freed += module_mem
|
memory_freed += module_mem
|
||||||
logging.debug("freed {}".format(n))
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
|
for param in params:
|
||||||
|
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||||
|
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
self.model.lowvram_patch_counter += patch_counter
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.model_loaded_weight_memory -= memory_freed
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
@ -1241,5 +1308,6 @@ class ModelPatcher:
|
|||||||
self.clear_cached_hook_weights()
|
self.clear_cached_hook_weights()
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
self.unpin_all_weights()
|
||||||
self.detach(unpatch_all=False)
|
self.detach(unpatch_all=False)
|
||||||
|
|
||||||
|
|||||||
@ -21,17 +21,23 @@ def rescale_zero_terminal_snr_sigmas(sigmas):
|
|||||||
alphas_bar[-1] = 4.8973451890853435e-08
|
alphas_bar[-1] = 4.8973451890853435e-08
|
||||||
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
return ((1 - alphas_bar) / alphas_bar) ** 0.5
|
||||||
|
|
||||||
|
def reshape_sigma(sigma, noise_dim):
|
||||||
|
if sigma.nelement() == 1:
|
||||||
|
return sigma.view(())
|
||||||
|
else:
|
||||||
|
return sigma.view(sigma.shape[:1] + (1,) * (noise_dim - 1))
|
||||||
|
|
||||||
class EPS:
|
class EPS:
|
||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
sigma = reshape_sigma(sigma, noise.ndim)
|
||||||
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
return noise / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
def calculate_denoised(self, sigma, model_output, model_input):
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||||
return model_input - model_output * sigma
|
return model_input - model_output * sigma
|
||||||
|
|
||||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
sigma = reshape_sigma(sigma, noise.ndim)
|
||||||
if max_denoise:
|
if max_denoise:
|
||||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||||
else:
|
else:
|
||||||
@ -45,12 +51,12 @@ class EPS:
|
|||||||
|
|
||||||
class V_PREDICTION(EPS):
|
class V_PREDICTION(EPS):
|
||||||
def calculate_denoised(self, sigma, model_output, model_input):
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
class EDM(V_PREDICTION):
|
class EDM(V_PREDICTION):
|
||||||
def calculate_denoised(self, sigma, model_output, model_input):
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||||
|
|
||||||
class CONST:
|
class CONST:
|
||||||
@ -58,15 +64,15 @@ class CONST:
|
|||||||
return noise
|
return noise
|
||||||
|
|
||||||
def calculate_denoised(self, sigma, model_output, model_input):
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||||
return model_input - model_output * sigma
|
return model_input - model_output * sigma
|
||||||
|
|
||||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
sigma = reshape_sigma(sigma, noise.ndim)
|
||||||
return sigma * noise + (1.0 - sigma) * latent_image
|
return sigma * noise + (1.0 - sigma) * latent_image
|
||||||
|
|
||||||
def inverse_noise_scaling(self, sigma, latent):
|
def inverse_noise_scaling(self, sigma, latent):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
|
sigma = reshape_sigma(sigma, latent.ndim)
|
||||||
return latent / (1.0 - sigma)
|
return latent / (1.0 - sigma)
|
||||||
|
|
||||||
class X0(EPS):
|
class X0(EPS):
|
||||||
@ -80,16 +86,16 @@ class IMG_TO_IMG(X0):
|
|||||||
class COSMOS_RFLOW:
|
class COSMOS_RFLOW:
|
||||||
def calculate_input(self, sigma, noise):
|
def calculate_input(self, sigma, noise):
|
||||||
sigma = (sigma / (sigma + 1))
|
sigma = (sigma / (sigma + 1))
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
sigma = reshape_sigma(sigma, noise.ndim)
|
||||||
return noise * (1.0 - sigma)
|
return noise * (1.0 - sigma)
|
||||||
|
|
||||||
def calculate_denoised(self, sigma, model_output, model_input):
|
def calculate_denoised(self, sigma, model_output, model_input):
|
||||||
sigma = (sigma / (sigma + 1))
|
sigma = (sigma / (sigma + 1))
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||||
return model_input * (1.0 - sigma) - model_output * sigma
|
return model_input * (1.0 - sigma) - model_output * sigma
|
||||||
|
|
||||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
|
sigma = reshape_sigma(sigma, noise.ndim)
|
||||||
noise = noise * sigma
|
noise = noise * sigma
|
||||||
noise += latent_image
|
noise += latent_image
|
||||||
return noise
|
return noise
|
||||||
|
|||||||
91
comfy/nested_tensor.py
Normal file
91
comfy/nested_tensor.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
class NestedTensor:
|
||||||
|
def __init__(self, tensors):
|
||||||
|
self.tensors = list(tensors)
|
||||||
|
self.is_nested = True
|
||||||
|
|
||||||
|
def _copy(self):
|
||||||
|
return NestedTensor(self.tensors)
|
||||||
|
|
||||||
|
def apply_operation(self, other, operation):
|
||||||
|
o = self._copy()
|
||||||
|
if isinstance(other, NestedTensor):
|
||||||
|
for i, t in enumerate(o.tensors):
|
||||||
|
o.tensors[i] = operation(t, other.tensors[i])
|
||||||
|
else:
|
||||||
|
for i, t in enumerate(o.tensors):
|
||||||
|
o.tensors[i] = operation(t, other)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def __add__(self, b):
|
||||||
|
return self.apply_operation(b, lambda x, y: x + y)
|
||||||
|
|
||||||
|
def __sub__(self, b):
|
||||||
|
return self.apply_operation(b, lambda x, y: x - y)
|
||||||
|
|
||||||
|
def __mul__(self, b):
|
||||||
|
return self.apply_operation(b, lambda x, y: x * y)
|
||||||
|
|
||||||
|
# def __itruediv__(self, b):
|
||||||
|
# return self.apply_operation(b, lambda x, y: x / y)
|
||||||
|
|
||||||
|
def __truediv__(self, b):
|
||||||
|
return self.apply_operation(b, lambda x, y: x / y)
|
||||||
|
|
||||||
|
def __getitem__(self, *args, **kwargs):
|
||||||
|
return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs))
|
||||||
|
|
||||||
|
def unbind(self):
|
||||||
|
return self.tensors
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
o = self._copy()
|
||||||
|
for i, t in enumerate(o.tensors):
|
||||||
|
o.tensors[i] = t.to(*args, **kwargs)
|
||||||
|
return o
|
||||||
|
|
||||||
|
def new_ones(self, *args, **kwargs):
|
||||||
|
return self.tensors[0].new_ones(*args, **kwargs)
|
||||||
|
|
||||||
|
def float(self):
|
||||||
|
return self.to(dtype=torch.float)
|
||||||
|
|
||||||
|
def chunk(self, *args, **kwargs):
|
||||||
|
return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs))
|
||||||
|
|
||||||
|
def size(self):
|
||||||
|
return self.tensors[0].size()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return self.tensors[0].shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ndim(self):
|
||||||
|
dims = 0
|
||||||
|
for t in self.tensors:
|
||||||
|
dims = max(t.ndim, dims)
|
||||||
|
return dims
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device(self):
|
||||||
|
return self.tensors[0].device
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self.tensors[0].dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layout(self):
|
||||||
|
return self.tensors[0].layout
|
||||||
|
|
||||||
|
|
||||||
|
def cat_nested(tensors, *args, **kwargs):
|
||||||
|
cated_tensors = []
|
||||||
|
for i in range(len(tensors[0].tensors)):
|
||||||
|
tens = []
|
||||||
|
for j in range(len(tensors)):
|
||||||
|
tens.append(tensors[j].tensors[i])
|
||||||
|
cated_tensors.append(torch.cat(tens, *args, **kwargs))
|
||||||
|
return NestedTensor(cated_tensors)
|
||||||
339
comfy/ops.py
339
comfy/ops.py
@ -24,13 +24,18 @@ import comfy.float
|
|||||||
import comfy.rmsnorm
|
import comfy.rmsnorm
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
|
def run_every_op():
|
||||||
|
if torch.compiler.is_compiling():
|
||||||
|
return
|
||||||
|
|
||||||
|
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||||
|
|
||||||
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
|
||||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available() and comfy.model_management.WINDOWS:
|
||||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||||
import inspect
|
import inspect
|
||||||
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
if "set_priority" in inspect.signature(sdpa_kernel).parameters:
|
||||||
@ -50,15 +55,26 @@ try:
|
|||||||
except (ModuleNotFoundError, TypeError):
|
except (ModuleNotFoundError, TypeError):
|
||||||
logging.warning("Could not set sdpa backend priority.")
|
logging.warning("Could not set sdpa backend priority.")
|
||||||
|
|
||||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||||
|
try:
|
||||||
|
if comfy.model_management.is_nvidia():
|
||||||
|
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||||
|
#TODO: change upper bound version once it's fixed'
|
||||||
|
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||||
|
logging.info("working around nvidia conv3d memory bug.")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
|
|
||||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|
||||||
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False):
|
||||||
|
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||||
|
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||||
|
# will add async-offload support to your cast and improve performance.
|
||||||
if input is not None:
|
if input is not None:
|
||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = input.dtype
|
dtype = input.dtype
|
||||||
@ -67,32 +83,58 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
if offloadable and (device != s.weight.device or
|
||||||
|
(s.bias is not None and device != s.bias.device)):
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
else:
|
||||||
|
offload_stream = None
|
||||||
|
|
||||||
if offload_stream is not None:
|
if offload_stream is not None:
|
||||||
wf_context = offload_stream
|
wf_context = offload_stream
|
||||||
else:
|
else:
|
||||||
wf_context = contextlib.nullcontext()
|
wf_context = contextlib.nullcontext()
|
||||||
|
|
||||||
bias = None
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
if s.bias is not None:
|
|
||||||
has_function = len(s.bias_function) > 0
|
|
||||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
|
||||||
|
|
||||||
if has_function:
|
weight_has_function = len(s.weight_function) > 0
|
||||||
|
bias_has_function = len(s.bias_function) > 0
|
||||||
|
|
||||||
|
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
|
||||||
|
|
||||||
|
bias = None
|
||||||
|
if s.bias is not None:
|
||||||
|
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||||
|
|
||||||
|
if bias_has_function:
|
||||||
with wf_context:
|
with wf_context:
|
||||||
for f in s.bias_function:
|
for f in s.bias_function:
|
||||||
bias = f(bias)
|
bias = f(bias)
|
||||||
|
|
||||||
has_function = len(s.weight_function) > 0
|
weight = weight.to(dtype=dtype)
|
||||||
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
|
if weight_has_function:
|
||||||
if has_function:
|
|
||||||
with wf_context:
|
with wf_context:
|
||||||
for f in s.weight_function:
|
for f in s.weight_function:
|
||||||
weight = f(weight)
|
weight = f(weight)
|
||||||
|
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
return weight, bias
|
if offloadable:
|
||||||
|
return weight, bias, offload_stream
|
||||||
|
else:
|
||||||
|
#Legacy function signature
|
||||||
|
return weight, bias
|
||||||
|
|
||||||
|
|
||||||
|
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||||
|
if offload_stream is None:
|
||||||
|
return
|
||||||
|
if weight is not None:
|
||||||
|
device = weight.device
|
||||||
|
else:
|
||||||
|
if bias is None:
|
||||||
|
return
|
||||||
|
device = bias.device
|
||||||
|
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||||
|
|
||||||
|
|
||||||
class CastWeightBiasOp:
|
class CastWeightBiasOp:
|
||||||
comfy_cast_weights = False
|
comfy_cast_weights = False
|
||||||
@ -105,10 +147,13 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
x = torch.nn.functional.linear(input, weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -119,10 +164,13 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return self._conv_forward(input, weight, bias)
|
x = self._conv_forward(input, weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -133,10 +181,13 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return self._conv_forward(input, weight, bias)
|
x = self._conv_forward(input, weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -146,11 +197,23 @@ class disable_weight_init:
|
|||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _conv_forward(self, input, weight, bias, *args, **kwargs):
|
||||||
|
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
|
||||||
|
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
|
||||||
|
if bias is not None:
|
||||||
|
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
|
||||||
|
return out
|
||||||
|
else:
|
||||||
|
return super()._conv_forward(input, weight, bias, *args, **kwargs)
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return self._conv_forward(input, weight, bias)
|
x = self._conv_forward(input, weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -161,10 +224,13 @@ class disable_weight_init:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -176,13 +242,17 @@ class disable_weight_init:
|
|||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
if self.weight is not None:
|
if self.weight is not None:
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
else:
|
else:
|
||||||
weight = None
|
weight = None
|
||||||
bias = None
|
bias = None
|
||||||
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
offload_stream = None
|
||||||
|
x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -195,13 +265,18 @@ class disable_weight_init:
|
|||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
if self.weight is not None:
|
if self.weight is not None:
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
else:
|
else:
|
||||||
weight = None
|
weight = None
|
||||||
return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
bias = None
|
||||||
# return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
offload_stream = None
|
||||||
|
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
||||||
|
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -217,12 +292,15 @@ class disable_weight_init:
|
|||||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||||
num_spatial_dims, self.dilation)
|
num_spatial_dims, self.dilation)
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return torch.nn.functional.conv_transpose2d(
|
x = torch.nn.functional.conv_transpose2d(
|
||||||
input, weight, bias, self.stride, self.padding,
|
input, weight, bias, self.stride, self.padding,
|
||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -238,12 +316,15 @@ class disable_weight_init:
|
|||||||
input, output_size, self.stride, self.padding, self.kernel_size,
|
input, output_size, self.stride, self.padding, self.kernel_size,
|
||||||
num_spatial_dims, self.dilation)
|
num_spatial_dims, self.dilation)
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return torch.nn.functional.conv_transpose1d(
|
x = torch.nn.functional.conv_transpose1d(
|
||||||
input, weight, bias, self.stride, self.padding,
|
input, weight, bias, self.stride, self.padding,
|
||||||
output_padding, self.groups, self.dilation)
|
output_padding, self.groups, self.dilation)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -258,10 +339,14 @@ class disable_weight_init:
|
|||||||
output_dtype = out_dtype
|
output_dtype = out_dtype
|
||||||
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
|
||||||
out_dtype = None
|
out_dtype = None
|
||||||
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
|
weight, bias, offload_stream = cast_bias_weight(self, device=input.device, dtype=out_dtype, offloadable=True)
|
||||||
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
return self.forward_comfy_cast_weights(*args, **kwargs)
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
@ -312,20 +397,18 @@ class manual_cast(disable_weight_init):
|
|||||||
|
|
||||||
|
|
||||||
def fp8_linear(self, input):
|
def fp8_linear(self, input):
|
||||||
|
"""
|
||||||
|
Legacy FP8 linear function for backward compatibility.
|
||||||
|
Uses QuantizedTensor subclass for dispatch.
|
||||||
|
"""
|
||||||
dtype = self.weight.dtype
|
dtype = self.weight.dtype
|
||||||
if dtype not in [torch.float8_e4m3fn]:
|
if dtype not in [torch.float8_e4m3fn]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
tensor_2d = False
|
|
||||||
if len(input.shape) == 2:
|
|
||||||
tensor_2d = True
|
|
||||||
input = input.unsqueeze(1)
|
|
||||||
|
|
||||||
input_shape = input.shape
|
|
||||||
input_dtype = input.dtype
|
input_dtype = input.dtype
|
||||||
if len(input.shape) == 3:
|
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
if input.ndim == 3 or input.ndim == 2:
|
||||||
w = w.t()
|
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
scale_input = self.scale_input
|
scale_input = self.scale_input
|
||||||
@ -337,23 +420,20 @@ def fp8_linear(self, input):
|
|||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||||
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
|
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||||
|
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||||
else:
|
else:
|
||||||
scale_input = scale_input.to(input.device)
|
scale_input = scale_input.to(input.device)
|
||||||
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
|
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||||
|
|
||||||
if bias is not None:
|
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||||
else:
|
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
|
||||||
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
|
||||||
|
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
|
||||||
|
|
||||||
if isinstance(o, tuple):
|
uncast_bias_weight(self, w, bias, offload_stream)
|
||||||
o = o[0]
|
return o
|
||||||
|
|
||||||
if tensor_2d:
|
|
||||||
return o.reshape(input_shape[0], -1)
|
|
||||||
|
|
||||||
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -373,8 +453,10 @@ class fp8_ops(manual_cast):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.info("Exception during fp8 op: {}".format(e))
|
logging.info("Exception during fp8 op: {}".format(e))
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
x = torch.nn.functional.linear(input, weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||||
@ -402,12 +484,14 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
if out is not None:
|
if out is not None:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
weight, bias = cast_bias_weight(self, input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
|
|
||||||
if weight.numel() < input.numel(): #TODO: optimize
|
if weight.numel() < input.numel(): #TODO: optimize
|
||||||
return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||||
if inplace:
|
if inplace:
|
||||||
@ -416,8 +500,10 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
|
|||||||
else:
|
else:
|
||||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||||
|
|
||||||
def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
|
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||||
|
if return_weight:
|
||||||
|
return weight
|
||||||
if inplace_update:
|
if inplace_update:
|
||||||
self.weight.data.copy_(weight)
|
self.weight.data.copy_(weight)
|
||||||
else:
|
else:
|
||||||
@ -444,7 +530,130 @@ if CUBLAS_IS_AVAILABLE:
|
|||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Mixed Precision Operations
|
||||||
|
# ==============================================================================
|
||||||
|
from .quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
QUANT_FORMAT_MIXINS = {
|
||||||
|
"float8_e4m3fn": {
|
||||||
|
"dtype": torch.float8_e4m3fn,
|
||||||
|
"layout_type": "TensorCoreFP8Layout",
|
||||||
|
"parameters": {
|
||||||
|
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||||
|
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class MixedPrecisionOps(disable_weight_init):
|
||||||
|
_layer_quant_config = {}
|
||||||
|
_compute_dtype = torch.bfloat16
|
||||||
|
|
||||||
|
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_features: int,
|
||||||
|
out_features: int,
|
||||||
|
bias: bool = True,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||||
|
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
if bias:
|
||||||
|
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||||
|
else:
|
||||||
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
|
self.tensor_class = None
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
|
device = self.factory_kwargs["device"]
|
||||||
|
layer_name = prefix.rstrip('.')
|
||||||
|
weight_key = f"{prefix}weight"
|
||||||
|
weight = state_dict.pop(weight_key, None)
|
||||||
|
if weight is None:
|
||||||
|
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||||
|
|
||||||
|
manually_loaded_keys = [weight_key]
|
||||||
|
|
||||||
|
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||||
|
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||||
|
else:
|
||||||
|
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||||
|
if quant_format is None:
|
||||||
|
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||||
|
|
||||||
|
mixin = QUANT_FORMAT_MIXINS[quant_format]
|
||||||
|
self.layout_type = mixin["layout_type"]
|
||||||
|
|
||||||
|
scale_key = f"{prefix}weight_scale"
|
||||||
|
layout_params = {
|
||||||
|
'scale': state_dict.pop(scale_key, None),
|
||||||
|
'orig_dtype': MixedPrecisionOps._compute_dtype
|
||||||
|
}
|
||||||
|
if layout_params['scale'] is not None:
|
||||||
|
manually_loaded_keys.append(scale_key)
|
||||||
|
|
||||||
|
self.weight = torch.nn.Parameter(
|
||||||
|
QuantizedTensor(weight.to(device=device, dtype=mixin["dtype"]), self.layout_type, layout_params),
|
||||||
|
requires_grad=False
|
||||||
|
)
|
||||||
|
|
||||||
|
for param_name, param_value in mixin["parameters"].items():
|
||||||
|
param_key = f"{prefix}{param_name}"
|
||||||
|
_v = state_dict.pop(param_key, None)
|
||||||
|
if _v is None:
|
||||||
|
continue
|
||||||
|
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||||
|
manually_loaded_keys.append(param_key)
|
||||||
|
|
||||||
|
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||||
|
|
||||||
|
for key in manually_loaded_keys:
|
||||||
|
if key in missing_keys:
|
||||||
|
missing_keys.remove(key)
|
||||||
|
|
||||||
|
def _forward(self, input, weight, bias):
|
||||||
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
|
x = self._forward(input, weight, bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, input, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
|
|
||||||
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
|
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||||
|
if (getattr(self, 'layout_type', None) is not None and
|
||||||
|
getattr(self, 'input_scale', None) is not None and
|
||||||
|
not isinstance(input, QuantizedTensor)):
|
||||||
|
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, fp8_dtype=self.weight.dtype)
|
||||||
|
return self._forward(input, self.weight, self.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||||
|
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||||
|
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
|
||||||
|
MixedPrecisionOps._compute_dtype = compute_dtype
|
||||||
|
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||||
|
return MixedPrecisionOps
|
||||||
|
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||||
if scaled_fp8 is not None:
|
if scaled_fp8 is not None:
|
||||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||||
|
|||||||
@ -150,7 +150,7 @@ def merge_nested_dicts(dict1: dict, dict2: dict, copy_dict1=True):
|
|||||||
for key, value in dict2.items():
|
for key, value in dict2.items():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
curr_value = merged_dict.setdefault(key, {})
|
curr_value = merged_dict.setdefault(key, {})
|
||||||
merged_dict[key] = merge_nested_dicts(value, curr_value)
|
merged_dict[key] = merge_nested_dicts(curr_value, value)
|
||||||
elif isinstance(value, list):
|
elif isinstance(value, list):
|
||||||
merged_dict.setdefault(key, []).extend(value)
|
merged_dict.setdefault(key, []).extend(value)
|
||||||
else:
|
else:
|
||||||
|
|||||||
512
comfy/quant_ops.py
Normal file
512
comfy/quant_ops.py
Normal file
@ -0,0 +1,512 @@
|
|||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
from typing import Tuple, Dict
|
||||||
|
|
||||||
|
_LAYOUT_REGISTRY = {}
|
||||||
|
_GENERIC_UTILS = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_layout_op(torch_op, layout_type):
|
||||||
|
"""
|
||||||
|
Decorator to register a layout-specific operation handler.
|
||||||
|
Args:
|
||||||
|
torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default)
|
||||||
|
layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||||
|
Example:
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout)
|
||||||
|
def fp8_linear(func, args, kwargs):
|
||||||
|
# FP8-specific linear implementation
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
def decorator(handler_func):
|
||||||
|
if torch_op not in _LAYOUT_REGISTRY:
|
||||||
|
_LAYOUT_REGISTRY[torch_op] = {}
|
||||||
|
_LAYOUT_REGISTRY[torch_op][layout_type] = handler_func
|
||||||
|
return handler_func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_generic_util(torch_op):
|
||||||
|
"""
|
||||||
|
Decorator to register a generic utility that works for all layouts.
|
||||||
|
Args:
|
||||||
|
torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@register_generic_util(torch.ops.aten.detach.default)
|
||||||
|
def generic_detach(func, args, kwargs):
|
||||||
|
# Works for any layout
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
def decorator(handler_func):
|
||||||
|
_GENERIC_UTILS[torch_op] = handler_func
|
||||||
|
return handler_func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _get_layout_from_args(args):
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, QuantizedTensor):
|
||||||
|
return arg._layout_type
|
||||||
|
elif isinstance(arg, (list, tuple)):
|
||||||
|
for item in arg:
|
||||||
|
if isinstance(item, QuantizedTensor):
|
||||||
|
return item._layout_type
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _move_layout_params_to_device(params, device):
|
||||||
|
new_params = {}
|
||||||
|
for k, v in params.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
new_params[k] = v.to(device=device)
|
||||||
|
else:
|
||||||
|
new_params[k] = v
|
||||||
|
return new_params
|
||||||
|
|
||||||
|
|
||||||
|
def _copy_layout_params(params):
|
||||||
|
new_params = {}
|
||||||
|
for k, v in params.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
new_params[k] = v.clone()
|
||||||
|
else:
|
||||||
|
new_params[k] = v
|
||||||
|
return new_params
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizedLayout:
|
||||||
|
"""
|
||||||
|
Base class for quantization layouts.
|
||||||
|
|
||||||
|
A layout encapsulates the format-specific logic for quantization/dequantization
|
||||||
|
and provides a uniform interface for extracting raw tensors needed for computation.
|
||||||
|
|
||||||
|
New quantization formats should subclass this and implement the required methods.
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]:
|
||||||
|
raise NotImplementedError(f"{cls.__name__} must implement quantize()")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dequantize(qdata, **layout_params) -> torch.Tensor:
|
||||||
|
raise NotImplementedError("TensorLayout must implement dequantize()")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_plain_tensors(cls, qtensor) -> torch.Tensor:
|
||||||
|
raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()")
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizedTensor(torch.Tensor):
|
||||||
|
"""
|
||||||
|
Universal quantized tensor that works with any layout.
|
||||||
|
|
||||||
|
This tensor subclass uses a pluggable layout system to support multiple
|
||||||
|
quantization formats (FP8, INT4, INT8, etc.) without code duplication.
|
||||||
|
|
||||||
|
The layout_type determines format-specific behavior, while common operations
|
||||||
|
(detach, clone, to) are handled generically.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_qdata: The quantized tensor data
|
||||||
|
_layout_type: Layout class (e.g., TensorCoreFP8Layout)
|
||||||
|
_layout_params: Dict with layout-specific params (scale, zero_point, etc.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __new__(cls, qdata, layout_type, layout_params):
|
||||||
|
"""
|
||||||
|
Create a quantized tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qdata: The quantized data tensor
|
||||||
|
layout_type: Layout class (subclass of QuantizedLayout)
|
||||||
|
layout_params: Dict with layout-specific parameters
|
||||||
|
"""
|
||||||
|
return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
|
||||||
|
|
||||||
|
def __init__(self, qdata, layout_type, layout_params):
|
||||||
|
self._qdata = qdata
|
||||||
|
self._layout_type = layout_type
|
||||||
|
self._layout_params = layout_params
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
layout_name = self._layout_type
|
||||||
|
param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2])
|
||||||
|
return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def layout_type(self):
|
||||||
|
return self._layout_type
|
||||||
|
|
||||||
|
def __tensor_flatten__(self):
|
||||||
|
"""
|
||||||
|
Tensor flattening protocol for proper device movement.
|
||||||
|
"""
|
||||||
|
inner_tensors = ["_qdata"]
|
||||||
|
ctx = {
|
||||||
|
"layout_type": self._layout_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor_params = {}
|
||||||
|
non_tensor_params = {}
|
||||||
|
for k, v in self._layout_params.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
tensor_params[k] = v
|
||||||
|
else:
|
||||||
|
non_tensor_params[k] = v
|
||||||
|
|
||||||
|
ctx["tensor_param_keys"] = list(tensor_params.keys())
|
||||||
|
ctx["non_tensor_params"] = non_tensor_params
|
||||||
|
|
||||||
|
for k, v in tensor_params.items():
|
||||||
|
attr_name = f"_layout_param_{k}"
|
||||||
|
object.__setattr__(self, attr_name, v)
|
||||||
|
inner_tensors.append(attr_name)
|
||||||
|
|
||||||
|
return inner_tensors, ctx
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
|
||||||
|
"""
|
||||||
|
Tensor unflattening protocol for proper device movement.
|
||||||
|
Reconstructs the QuantizedTensor after device movement.
|
||||||
|
"""
|
||||||
|
layout_type = ctx["layout_type"]
|
||||||
|
layout_params = dict(ctx["non_tensor_params"])
|
||||||
|
|
||||||
|
for key in ctx["tensor_param_keys"]:
|
||||||
|
attr_name = f"_layout_param_{key}"
|
||||||
|
layout_params[key] = inner_tensors[attr_name]
|
||||||
|
|
||||||
|
return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
|
||||||
|
qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
|
||||||
|
return cls(qdata, layout_type, layout_params)
|
||||||
|
|
||||||
|
def dequantize(self) -> torch.Tensor:
|
||||||
|
return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||||
|
kwargs = kwargs or {}
|
||||||
|
|
||||||
|
# Step 1: Check generic utilities first (detach, clone, to, etc.)
|
||||||
|
if func in _GENERIC_UTILS:
|
||||||
|
return _GENERIC_UTILS[func](func, args, kwargs)
|
||||||
|
|
||||||
|
# Step 2: Check layout-specific handlers (linear, matmul, etc.)
|
||||||
|
layout_type = _get_layout_from_args(args)
|
||||||
|
if layout_type and func in _LAYOUT_REGISTRY:
|
||||||
|
handler = _LAYOUT_REGISTRY[func].get(layout_type)
|
||||||
|
if handler:
|
||||||
|
return handler(func, args, kwargs)
|
||||||
|
|
||||||
|
# Step 3: Fallback to dequantization
|
||||||
|
if isinstance(args[0] if args else None, QuantizedTensor):
|
||||||
|
logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. kwargs={kwargs}")
|
||||||
|
return cls._dequant_and_fallback(func, args, kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _dequant_and_fallback(cls, func, args, kwargs):
|
||||||
|
def dequant_arg(arg):
|
||||||
|
if isinstance(arg, QuantizedTensor):
|
||||||
|
return arg.dequantize()
|
||||||
|
elif isinstance(arg, (list, tuple)):
|
||||||
|
return type(arg)(dequant_arg(a) for a in arg)
|
||||||
|
return arg
|
||||||
|
|
||||||
|
new_args = dequant_arg(args)
|
||||||
|
new_kwargs = dequant_arg(kwargs)
|
||||||
|
return func(*new_args, **new_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# Generic Utilities (Layout-Agnostic Operations)
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
def _create_transformed_qtensor(qt, transform_fn):
|
||||||
|
new_data = transform_fn(qt._qdata)
|
||||||
|
new_params = _copy_layout_params(qt._layout_params)
|
||||||
|
return QuantizedTensor(new_data, qt._layout_type, new_params)
|
||||||
|
|
||||||
|
|
||||||
|
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||||
|
if target_dtype is not None and target_dtype != qt.dtype:
|
||||||
|
logging.warning(
|
||||||
|
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||||
|
f"but not supported for quantized tensors. Ignoring dtype."
|
||||||
|
)
|
||||||
|
|
||||||
|
if target_layout is not None and target_layout != torch.strided:
|
||||||
|
logging.warning(
|
||||||
|
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||||
|
f"but not supported. Ignoring layout."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle device transfer
|
||||||
|
current_device = qt._qdata.device
|
||||||
|
if target_device is not None:
|
||||||
|
# Normalize device for comparison
|
||||||
|
if isinstance(target_device, str):
|
||||||
|
target_device = torch.device(target_device)
|
||||||
|
if isinstance(current_device, str):
|
||||||
|
current_device = torch.device(current_device)
|
||||||
|
|
||||||
|
if target_device != current_device:
|
||||||
|
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||||
|
new_q_data = qt._qdata.to(device=target_device)
|
||||||
|
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||||
|
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||||
|
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||||
|
return new_qt
|
||||||
|
|
||||||
|
logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original")
|
||||||
|
return qt
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten.detach.default)
|
||||||
|
def generic_detach(func, args, kwargs):
|
||||||
|
"""Detach operation - creates a detached copy of the quantized tensor."""
|
||||||
|
qt = args[0]
|
||||||
|
if isinstance(qt, QuantizedTensor):
|
||||||
|
return _create_transformed_qtensor(qt, lambda x: x.detach())
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten.clone.default)
|
||||||
|
def generic_clone(func, args, kwargs):
|
||||||
|
"""Clone operation - creates a deep copy of the quantized tensor."""
|
||||||
|
qt = args[0]
|
||||||
|
if isinstance(qt, QuantizedTensor):
|
||||||
|
return _create_transformed_qtensor(qt, lambda x: x.clone())
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten._to_copy.default)
|
||||||
|
def generic_to_copy(func, args, kwargs):
|
||||||
|
"""Device/dtype transfer operation - handles .to(device) calls."""
|
||||||
|
qt = args[0]
|
||||||
|
if isinstance(qt, QuantizedTensor):
|
||||||
|
return _handle_device_transfer(
|
||||||
|
qt,
|
||||||
|
target_device=kwargs.get('device', None),
|
||||||
|
target_dtype=kwargs.get('dtype', None),
|
||||||
|
op_name="_to_copy"
|
||||||
|
)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten.to.dtype_layout)
|
||||||
|
def generic_to_dtype_layout(func, args, kwargs):
|
||||||
|
"""Handle .to(device) calls using the dtype_layout variant."""
|
||||||
|
qt = args[0]
|
||||||
|
if isinstance(qt, QuantizedTensor):
|
||||||
|
return _handle_device_transfer(
|
||||||
|
qt,
|
||||||
|
target_device=kwargs.get('device', None),
|
||||||
|
target_dtype=kwargs.get('dtype', None),
|
||||||
|
target_layout=kwargs.get('layout', None),
|
||||||
|
op_name="to"
|
||||||
|
)
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten.copy_.default)
|
||||||
|
def generic_copy_(func, args, kwargs):
|
||||||
|
qt_dest = args[0]
|
||||||
|
src = args[1]
|
||||||
|
|
||||||
|
if isinstance(qt_dest, QuantizedTensor):
|
||||||
|
if isinstance(src, QuantizedTensor):
|
||||||
|
# Copy from another quantized tensor
|
||||||
|
qt_dest._qdata.copy_(src._qdata)
|
||||||
|
qt_dest._layout_type = src._layout_type
|
||||||
|
qt_dest._layout_params = _copy_layout_params(src._layout_params)
|
||||||
|
else:
|
||||||
|
# Copy from regular tensor - just copy raw data
|
||||||
|
qt_dest._qdata.copy_(src)
|
||||||
|
return qt_dest
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||||
|
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
# FP8 Layout + Operation Handlers
|
||||||
|
# ==============================================================================
|
||||||
|
class TensorCoreFP8Layout(QuantizedLayout):
|
||||||
|
"""
|
||||||
|
Storage format:
|
||||||
|
- qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2)
|
||||||
|
- scale: Scalar tensor (float32) for dequantization
|
||||||
|
- orig_dtype: Original dtype before quantization (for casting back)
|
||||||
|
"""
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||||
|
|
||||||
|
if not isinstance(scale, torch.Tensor):
|
||||||
|
scale = torch.tensor(scale)
|
||||||
|
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||||
|
|
||||||
|
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
|
||||||
|
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
|
||||||
|
# lp_amax = torch.finfo(dtype).max
|
||||||
|
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||||
|
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
|
||||||
|
|
||||||
|
layout_params = {
|
||||||
|
'scale': scale,
|
||||||
|
'orig_dtype': orig_dtype
|
||||||
|
}
|
||||||
|
return qdata, layout_params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||||
|
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||||
|
return plain_tensor * scale
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_plain_tensors(cls, qtensor):
|
||||||
|
return qtensor._qdata, qtensor._layout_params['scale']
|
||||||
|
|
||||||
|
|
||||||
|
LAYOUTS = {
|
||||||
|
"TensorCoreFP8Layout": TensorCoreFP8Layout,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_linear(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
weight = args[1]
|
||||||
|
bias = args[2] if len(args) > 2 else None
|
||||||
|
|
||||||
|
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||||
|
|
||||||
|
out_dtype = kwargs.get("out_dtype")
|
||||||
|
if out_dtype is None:
|
||||||
|
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||||
|
|
||||||
|
weight_t = plain_weight.t()
|
||||||
|
|
||||||
|
tensor_2d = False
|
||||||
|
if len(plain_input.shape) == 2:
|
||||||
|
tensor_2d = True
|
||||||
|
plain_input = plain_input.unsqueeze(1)
|
||||||
|
|
||||||
|
input_shape = plain_input.shape
|
||||||
|
if len(input_shape) != 3:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
plain_input.reshape(-1, input_shape[2]).contiguous(),
|
||||||
|
weight_t,
|
||||||
|
bias=bias,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||||
|
output = output[0]
|
||||||
|
|
||||||
|
if not tensor_2d:
|
||||||
|
output = output.reshape((-1, input_shape[1], weight.shape[0]))
|
||||||
|
|
||||||
|
if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
||||||
|
output_scale = scale_a * scale_b
|
||||||
|
output_params = {
|
||||||
|
'scale': output_scale,
|
||||||
|
'orig_dtype': input_tensor._layout_params['orig_dtype']
|
||||||
|
}
|
||||||
|
return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
|
||||||
|
else:
|
||||||
|
return output
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}")
|
||||||
|
|
||||||
|
# Case 2: DQ Fallback
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight = weight.dequantize()
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
input_tensor = input_tensor.dequantize()
|
||||||
|
|
||||||
|
return torch.nn.functional.linear(input_tensor, weight, bias)
|
||||||
|
|
||||||
|
def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None):
|
||||||
|
if out_dtype is None:
|
||||||
|
out_dtype = input_tensor._layout_params['orig_dtype']
|
||||||
|
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight)
|
||||||
|
|
||||||
|
output = torch._scaled_mm(
|
||||||
|
plain_input.contiguous(),
|
||||||
|
plain_weight,
|
||||||
|
bias=bias,
|
||||||
|
scale_a=scale_a,
|
||||||
|
scale_b=scale_b,
|
||||||
|
out_dtype=out_dtype,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4
|
||||||
|
output = output[0]
|
||||||
|
return output
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_addmm(func, args, kwargs):
|
||||||
|
input_tensor = args[1]
|
||||||
|
weight = args[2]
|
||||||
|
bias = args[0]
|
||||||
|
|
||||||
|
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||||
|
return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None))
|
||||||
|
|
||||||
|
a = list(args)
|
||||||
|
if isinstance(args[0], QuantizedTensor):
|
||||||
|
a[0] = args[0].dequantize()
|
||||||
|
if isinstance(args[1], QuantizedTensor):
|
||||||
|
a[1] = args[1].dequantize()
|
||||||
|
if isinstance(args[2], QuantizedTensor):
|
||||||
|
a[2] = args[2].dequantize()
|
||||||
|
|
||||||
|
return func(*a, **kwargs)
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_mm(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
weight = args[1]
|
||||||
|
|
||||||
|
if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor):
|
||||||
|
return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None))
|
||||||
|
|
||||||
|
a = list(args)
|
||||||
|
if isinstance(args[0], QuantizedTensor):
|
||||||
|
a[0] = args[0].dequantize()
|
||||||
|
if isinstance(args[1], QuantizedTensor):
|
||||||
|
a[1] = args[1].dequantize()
|
||||||
|
return func(*a, **kwargs)
|
||||||
|
|
||||||
|
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
|
||||||
|
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
|
||||||
|
def fp8_func(func, args, kwargs):
|
||||||
|
input_tensor = args[0]
|
||||||
|
if isinstance(input_tensor, QuantizedTensor):
|
||||||
|
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
|
||||||
|
ar = list(args)
|
||||||
|
ar[0] = plain_input
|
||||||
|
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
|
||||||
|
return func(*args, **kwargs)
|
||||||
@ -4,13 +4,9 @@ import comfy.samplers
|
|||||||
import comfy.utils
|
import comfy.utils
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
|
import comfy.nested_tensor
|
||||||
|
|
||||||
def prepare_noise(latent_image, seed, noise_inds=None):
|
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||||
"""
|
|
||||||
creates random noise given a latent image and a seed.
|
|
||||||
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
|
||||||
"""
|
|
||||||
generator = torch.manual_seed(seed)
|
|
||||||
if noise_inds is None:
|
if noise_inds is None:
|
||||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
||||||
|
|
||||||
@ -21,10 +17,29 @@ def prepare_noise(latent_image, seed, noise_inds=None):
|
|||||||
if i in unique_inds:
|
if i in unique_inds:
|
||||||
noises.append(noise)
|
noises.append(noise)
|
||||||
noises = [noises[i] for i in inverse]
|
noises = [noises[i] for i in inverse]
|
||||||
noises = torch.cat(noises, axis=0)
|
return torch.cat(noises, axis=0)
|
||||||
|
|
||||||
|
def prepare_noise(latent_image, seed, noise_inds=None):
|
||||||
|
"""
|
||||||
|
creates random noise given a latent image and a seed.
|
||||||
|
optional arg skip can be used to skip and discard x number of noise generations for a given seed
|
||||||
|
"""
|
||||||
|
generator = torch.manual_seed(seed)
|
||||||
|
|
||||||
|
if latent_image.is_nested:
|
||||||
|
tensors = latent_image.unbind()
|
||||||
|
noises = []
|
||||||
|
for t in tensors:
|
||||||
|
noises.append(prepare_noise_inner(t, generator, noise_inds))
|
||||||
|
noises = comfy.nested_tensor.NestedTensor(noises)
|
||||||
|
else:
|
||||||
|
noises = prepare_noise_inner(latent_image, generator, noise_inds)
|
||||||
|
|
||||||
return noises
|
return noises
|
||||||
|
|
||||||
def fix_empty_latent_channels(model, latent_image):
|
def fix_empty_latent_channels(model, latent_image):
|
||||||
|
if latent_image.is_nested:
|
||||||
|
return latent_image
|
||||||
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
|
||||||
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
|
||||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||||
|
|||||||
@ -306,17 +306,10 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens
|
|||||||
copy_dict1=False)
|
copy_dict1=False)
|
||||||
|
|
||||||
if patches is not None:
|
if patches is not None:
|
||||||
# TODO: replace with merge_nested_dicts function
|
transformer_options["patches"] = comfy.patcher_extension.merge_nested_dicts(
|
||||||
if "patches" in transformer_options:
|
transformer_options.get("patches", {}),
|
||||||
cur_patches = transformer_options["patches"].copy()
|
patches
|
||||||
for p in patches:
|
)
|
||||||
if p in cur_patches:
|
|
||||||
cur_patches[p] = cur_patches[p] + patches[p]
|
|
||||||
else:
|
|
||||||
cur_patches[p] = patches[p]
|
|
||||||
transformer_options["patches"] = cur_patches
|
|
||||||
else:
|
|
||||||
transformer_options["patches"] = patches
|
|
||||||
|
|
||||||
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
transformer_options["cond_or_uncond"] = cond_or_uncond[:]
|
||||||
transformer_options["uuids"] = uuids[:]
|
transformer_options["uuids"] = uuids[:]
|
||||||
@ -789,7 +782,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}):
|
|||||||
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
return KSAMPLER(sampler_function, extra_options, inpaint_options)
|
||||||
|
|
||||||
|
|
||||||
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None):
|
def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None):
|
||||||
for k in conds:
|
for k in conds:
|
||||||
conds[k] = conds[k][:]
|
conds[k] = conds[k][:]
|
||||||
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device)
|
||||||
@ -799,7 +792,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N
|
|||||||
|
|
||||||
if hasattr(model, 'extra_conds'):
|
if hasattr(model, 'extra_conds'):
|
||||||
for k in conds:
|
for k in conds:
|
||||||
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed)
|
conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes)
|
||||||
|
|
||||||
#make sure each cond area has an opposite one with the same area
|
#make sure each cond area has an opposite one with the same area
|
||||||
for k in conds:
|
for k in conds:
|
||||||
@ -969,11 +962,11 @@ class CFGGuider:
|
|||||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||||
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed)
|
||||||
|
|
||||||
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed):
|
def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
|
||||||
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
if latent_image is not None and torch.count_nonzero(latent_image) > 0: #Don't shift the empty latent image.
|
||||||
latent_image = self.inner_model.process_latent_in(latent_image)
|
latent_image = self.inner_model.process_latent_in(latent_image)
|
||||||
|
|
||||||
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed)
|
self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes)
|
||||||
|
|
||||||
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
extra_model_options = comfy.model_patcher.create_model_options_clone(self.model_options)
|
||||||
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas
|
||||||
@ -987,7 +980,7 @@ class CFGGuider:
|
|||||||
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||||
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
return self.inner_model.process_latent_out(samples.to(torch.float32))
|
||||||
|
|
||||||
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
|
||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
@ -1001,7 +994,7 @@ class CFGGuider:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
self.model_patcher.pre_run()
|
self.model_patcher.pre_run()
|
||||||
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
finally:
|
finally:
|
||||||
self.model_patcher.cleanup()
|
self.model_patcher.cleanup()
|
||||||
|
|
||||||
@ -1014,6 +1007,12 @@ class CFGGuider:
|
|||||||
if sigmas.shape[-1] == 0:
|
if sigmas.shape[-1] == 0:
|
||||||
return latent_image
|
return latent_image
|
||||||
|
|
||||||
|
if latent_image.is_nested:
|
||||||
|
latent_image, latent_shapes = comfy.utils.pack_latents(latent_image.unbind())
|
||||||
|
noise, _ = comfy.utils.pack_latents(noise.unbind())
|
||||||
|
else:
|
||||||
|
latent_shapes = [latent_image.shape]
|
||||||
|
|
||||||
self.conds = {}
|
self.conds = {}
|
||||||
for k in self.original_conds:
|
for k in self.original_conds:
|
||||||
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k]))
|
||||||
@ -1033,7 +1032,7 @@ class CFGGuider:
|
|||||||
self,
|
self,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True)
|
||||||
)
|
)
|
||||||
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
|
||||||
finally:
|
finally:
|
||||||
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
cast_to_load_options(self.model_options, device=self.model_patcher.offload_device)
|
||||||
self.model_options = orig_model_options
|
self.model_options = orig_model_options
|
||||||
@ -1041,6 +1040,9 @@ class CFGGuider:
|
|||||||
self.model_patcher.restore_hook_patches()
|
self.model_patcher.restore_hook_patches()
|
||||||
|
|
||||||
del self.conds
|
del self.conds
|
||||||
|
|
||||||
|
if len(latent_shapes) > 1:
|
||||||
|
output = comfy.nested_tensor.NestedTensor(comfy.utils.unpack_latents(output, latent_shapes))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
60
comfy/sd.py
60
comfy/sd.py
@ -18,6 +18,7 @@ import comfy.ldm.wan.vae2_2
|
|||||||
import comfy.ldm.hunyuan3d.vae
|
import comfy.ldm.hunyuan3d.vae
|
||||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||||
import comfy.ldm.hunyuan_video.vae
|
import comfy.ldm.hunyuan_video.vae
|
||||||
|
import comfy.ldm.mmaudio.vae.autoencoder
|
||||||
import comfy.pixel_space_convert
|
import comfy.pixel_space_convert
|
||||||
import yaml
|
import yaml
|
||||||
import math
|
import math
|
||||||
@ -142,6 +143,9 @@ class CLIP:
|
|||||||
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
n.apply_hooks_to_conds = self.apply_hooks_to_conds
|
||||||
return n
|
return n
|
||||||
|
|
||||||
|
def get_ram_usage(self):
|
||||||
|
return self.patcher.get_ram_usage()
|
||||||
|
|
||||||
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
|
||||||
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
return self.patcher.add_patches(patches, strength_patch, strength_model)
|
||||||
|
|
||||||
@ -275,8 +279,13 @@ class VAE:
|
|||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
|
if model_management.is_amd():
|
||||||
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
|
VAE_KL_MEM_RATIO = 2.73
|
||||||
|
else:
|
||||||
|
VAE_KL_MEM_RATIO = 1.0
|
||||||
|
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO #These are for AutoencoderKL and need tweaking (should be lower)
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype) * VAE_KL_MEM_RATIO
|
||||||
self.downscale_ratio = 8
|
self.downscale_ratio = 8
|
||||||
self.upscale_ratio = 8
|
self.upscale_ratio = 8
|
||||||
self.latent_channels = 4
|
self.latent_channels = 4
|
||||||
@ -287,10 +296,12 @@ class VAE:
|
|||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
self.disable_offload = False
|
self.disable_offload = False
|
||||||
self.not_video = False
|
self.not_video = False
|
||||||
|
self.size = None
|
||||||
|
|
||||||
self.downscale_index_formula = None
|
self.downscale_index_formula = None
|
||||||
self.upscale_index_formula = None
|
self.upscale_index_formula = None
|
||||||
self.extra_1d_channel = None
|
self.extra_1d_channel = None
|
||||||
|
self.crop_input = True
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
if "decoder.mid.block_1.mix_factor" in sd:
|
if "decoder.mid.block_1.mix_factor" in sd:
|
||||||
@ -542,6 +553,25 @@ class VAE:
|
|||||||
self.latent_channels = 3
|
self.latent_channels = 3
|
||||||
self.latent_dim = 2
|
self.latent_dim = 2
|
||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
|
elif "vocoder.activation_post.downsample.lowpass.filter" in sd: #MMAudio VAE
|
||||||
|
sample_rate = 16000
|
||||||
|
if sample_rate == 16000:
|
||||||
|
mode = '16k'
|
||||||
|
else:
|
||||||
|
mode = '44k'
|
||||||
|
|
||||||
|
self.first_stage_model = comfy.ldm.mmaudio.vae.autoencoder.AudioAutoencoder(mode=mode)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (30 * shape[2]) * model_management.dtype_size(dtype)
|
||||||
|
self.memory_used_decode = lambda shape, dtype: (90 * shape[2] * 1411.2) * model_management.dtype_size(dtype)
|
||||||
|
self.latent_channels = 20
|
||||||
|
self.output_channels = 2
|
||||||
|
self.upscale_ratio = 512 * (44100 / sample_rate)
|
||||||
|
self.downscale_ratio = 512 * (44100 / sample_rate)
|
||||||
|
self.latent_dim = 1
|
||||||
|
self.process_output = lambda audio: audio
|
||||||
|
self.process_input = lambda audio: audio
|
||||||
|
self.working_dtypes = [torch.float32]
|
||||||
|
self.crop_input = False
|
||||||
else:
|
else:
|
||||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||||
self.first_stage_model = None
|
self.first_stage_model = None
|
||||||
@ -569,12 +599,25 @@ class VAE:
|
|||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||||
|
self.model_size()
|
||||||
|
|
||||||
|
def model_size(self):
|
||||||
|
if self.size is not None:
|
||||||
|
return self.size
|
||||||
|
self.size = comfy.model_management.module_size(self.first_stage_model)
|
||||||
|
return self.size
|
||||||
|
|
||||||
|
def get_ram_usage(self):
|
||||||
|
return self.model_size()
|
||||||
|
|
||||||
def throw_exception_if_invalid(self):
|
def throw_exception_if_invalid(self):
|
||||||
if self.first_stage_model is None:
|
if self.first_stage_model is None:
|
||||||
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
|
||||||
|
|
||||||
def vae_encode_crop_pixels(self, pixels):
|
def vae_encode_crop_pixels(self, pixels):
|
||||||
|
if not self.crop_input:
|
||||||
|
return pixels
|
||||||
|
|
||||||
downscale_ratio = self.spacial_compression_encode()
|
downscale_ratio = self.spacial_compression_encode()
|
||||||
|
|
||||||
dims = pixels.shape[1:-1]
|
dims = pixels.shape[1:-1]
|
||||||
@ -1233,7 +1276,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
return (model_patcher, clip, vae, clipvision)
|
return (model_patcher, clip, vae, clipvision)
|
||||||
|
|
||||||
|
|
||||||
def load_diffusion_model_state_dict(sd, model_options={}):
|
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||||
"""
|
"""
|
||||||
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
|
||||||
|
|
||||||
@ -1267,7 +1310,7 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
|||||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
|
||||||
load_device = model_management.get_torch_device()
|
load_device = model_management.get_torch_device()
|
||||||
model_config = model_detection.model_config_from_unet(sd, "")
|
model_config = model_detection.model_config_from_unet(sd, "", metadata=metadata)
|
||||||
|
|
||||||
if model_config is not None:
|
if model_config is not None:
|
||||||
new_sd = sd
|
new_sd = sd
|
||||||
@ -1301,7 +1344,10 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
|||||||
else:
|
else:
|
||||||
unet_dtype = dtype
|
unet_dtype = dtype
|
||||||
|
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
if model_config.layer_quant_config is not None:
|
||||||
|
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||||
|
else:
|
||||||
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||||
if model_options.get("fp8_optimizations", False):
|
if model_options.get("fp8_optimizations", False):
|
||||||
@ -1317,8 +1363,8 @@ def load_diffusion_model_state_dict(sd, model_options={}):
|
|||||||
|
|
||||||
|
|
||||||
def load_diffusion_model(unet_path, model_options={}):
|
def load_diffusion_model(unet_path, model_options={}):
|
||||||
sd = comfy.utils.load_torch_file(unet_path)
|
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
|
||||||
model = load_diffusion_model_state_dict(sd, model_options=model_options)
|
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
|
||||||
if model is None:
|
if model is None:
|
||||||
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
|
||||||
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class BASE:
|
|||||||
manual_cast_dtype = None
|
manual_cast_dtype = None
|
||||||
custom_operations = None
|
custom_operations = None
|
||||||
scaled_fp8 = None
|
scaled_fp8 = None
|
||||||
|
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||||
optimizations = {"fp8": False}
|
optimizations = {"fp8": False}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -39,7 +39,11 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
|||||||
pass
|
pass
|
||||||
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
|
||||||
|
|
||||||
from numpy.core.multiarray import scalar
|
def scalar(*args, **kwargs):
|
||||||
|
from numpy.core.multiarray import scalar as sc
|
||||||
|
return sc(*args, **kwargs)
|
||||||
|
scalar.__module__ = "numpy.core.multiarray"
|
||||||
|
|
||||||
from numpy import dtype
|
from numpy import dtype
|
||||||
from numpy.dtypes import Float64DType
|
from numpy.dtypes import Float64DType
|
||||||
from _codecs import encode
|
from _codecs import encode
|
||||||
@ -1114,3 +1118,25 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
|
|||||||
dim=1
|
dim=1
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def pack_latents(latents):
|
||||||
|
latent_shapes = []
|
||||||
|
tensors = []
|
||||||
|
for tensor in latents:
|
||||||
|
latent_shapes.append(tensor.shape)
|
||||||
|
tensors.append(tensor.reshape(tensor.shape[0], 1, -1))
|
||||||
|
|
||||||
|
latent = torch.cat(tensors, dim=-1)
|
||||||
|
return latent, latent_shapes
|
||||||
|
|
||||||
|
def unpack_latents(combined_latent, latent_shapes):
|
||||||
|
if len(latent_shapes) > 1:
|
||||||
|
output_tensors = []
|
||||||
|
for shape in latent_shapes:
|
||||||
|
cut = math.prod(shape[1:])
|
||||||
|
tens = combined_latent[:, :, :cut]
|
||||||
|
combined_latent = combined_latent[:, :, cut:]
|
||||||
|
output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:]))
|
||||||
|
else:
|
||||||
|
output_tensors = combined_latent
|
||||||
|
return output_tensors
|
||||||
|
|||||||
@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
|
|||||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||||
from comfy_api.latest._io import _IO as io #noqa: F401
|
from . import _io as io
|
||||||
from comfy_api.latest._ui import _UI as ui #noqa: F401
|
from . import _ui as ui
|
||||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||||
from comfy_execution.utils import get_executing_context
|
from comfy_execution.utils import get_executing_context
|
||||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||||
@ -114,6 +114,10 @@ if TYPE_CHECKING:
|
|||||||
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
ComfyAPISync: Type[comfy_api.latest.generated.ComfyAPISyncStub.ComfyAPISyncStub]
|
||||||
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
ComfyAPISync = create_sync_class(ComfyAPI_latest)
|
||||||
|
|
||||||
|
# create new aliases for io and ui
|
||||||
|
IO = io
|
||||||
|
UI = ui
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ComfyAPI",
|
"ComfyAPI",
|
||||||
"ComfyAPISync",
|
"ComfyAPISync",
|
||||||
@ -121,4 +125,8 @@ __all__ = [
|
|||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
"ComfyExtension",
|
"ComfyExtension",
|
||||||
|
"io",
|
||||||
|
"IO",
|
||||||
|
"ui",
|
||||||
|
"UI",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, IO
|
||||||
import io
|
import io
|
||||||
import av
|
import av
|
||||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||||
@ -23,7 +23,7 @@ class VideoInput(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def save_to(
|
def save_to(
|
||||||
self,
|
self,
|
||||||
path: str,
|
path: Union[str, IO[bytes]],
|
||||||
format: VideoContainer = VideoContainer.AUTO,
|
format: VideoContainer = VideoContainer.AUTO,
|
||||||
codec: VideoCodec = VideoCodec.AUTO,
|
codec: VideoCodec = VideoCodec.AUTO,
|
||||||
metadata: Optional[dict] = None
|
metadata: Optional[dict] = None
|
||||||
|
|||||||
@ -336,11 +336,25 @@ class Combo(ComfyTypeIO):
|
|||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
"""Combo input (dropdown)."""
|
"""Combo input (dropdown)."""
|
||||||
Type = str
|
Type = str
|
||||||
def __init__(self, id: str, options: list[str]=None, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(
|
||||||
default: str=None, control_after_generate: bool=None,
|
self,
|
||||||
upload: UploadType=None, image_folder: FolderType=None,
|
id: str,
|
||||||
remote: RemoteOptions=None,
|
options: list[str] | list[int] | type[Enum] = None,
|
||||||
socketless: bool=None):
|
display_name: str=None,
|
||||||
|
optional=False,
|
||||||
|
tooltip: str=None,
|
||||||
|
lazy: bool=None,
|
||||||
|
default: str | int | Enum = None,
|
||||||
|
control_after_generate: bool=None,
|
||||||
|
upload: UploadType=None,
|
||||||
|
image_folder: FolderType=None,
|
||||||
|
remote: RemoteOptions=None,
|
||||||
|
socketless: bool=None,
|
||||||
|
):
|
||||||
|
if isinstance(options, type) and issubclass(options, Enum):
|
||||||
|
options = [v.value for v in options]
|
||||||
|
if isinstance(default, Enum):
|
||||||
|
default = default.value
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless)
|
||||||
self.multiselect = False
|
self.multiselect = False
|
||||||
self.options = options
|
self.options = options
|
||||||
@ -1568,78 +1582,78 @@ class _UIOutput(ABC):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class _IO:
|
__all__ = [
|
||||||
FolderType = FolderType
|
"FolderType",
|
||||||
UploadType = UploadType
|
"UploadType",
|
||||||
RemoteOptions = RemoteOptions
|
"RemoteOptions",
|
||||||
NumberDisplay = NumberDisplay
|
"NumberDisplay",
|
||||||
|
|
||||||
comfytype = staticmethod(comfytype)
|
"comfytype",
|
||||||
Custom = staticmethod(Custom)
|
"Custom",
|
||||||
Input = Input
|
"Input",
|
||||||
WidgetInput = WidgetInput
|
"WidgetInput",
|
||||||
Output = Output
|
"Output",
|
||||||
ComfyTypeI = ComfyTypeI
|
"ComfyTypeI",
|
||||||
ComfyTypeIO = ComfyTypeIO
|
"ComfyTypeIO",
|
||||||
#---------------------------------
|
|
||||||
# Supported Types
|
# Supported Types
|
||||||
Boolean = Boolean
|
"Boolean",
|
||||||
Int = Int
|
"Int",
|
||||||
Float = Float
|
"Float",
|
||||||
String = String
|
"String",
|
||||||
Combo = Combo
|
"Combo",
|
||||||
MultiCombo = MultiCombo
|
"MultiCombo",
|
||||||
Image = Image
|
"Image",
|
||||||
WanCameraEmbedding = WanCameraEmbedding
|
"WanCameraEmbedding",
|
||||||
Webcam = Webcam
|
"Webcam",
|
||||||
Mask = Mask
|
"Mask",
|
||||||
Latent = Latent
|
"Latent",
|
||||||
Conditioning = Conditioning
|
"Conditioning",
|
||||||
Sampler = Sampler
|
"Sampler",
|
||||||
Sigmas = Sigmas
|
"Sigmas",
|
||||||
Noise = Noise
|
"Noise",
|
||||||
Guider = Guider
|
"Guider",
|
||||||
Clip = Clip
|
"Clip",
|
||||||
ControlNet = ControlNet
|
"ControlNet",
|
||||||
Vae = Vae
|
"Vae",
|
||||||
Model = Model
|
"Model",
|
||||||
ClipVision = ClipVision
|
"ClipVision",
|
||||||
ClipVisionOutput = ClipVisionOutput
|
"ClipVisionOutput",
|
||||||
AudioEncoder = AudioEncoder
|
"AudioEncoder",
|
||||||
AudioEncoderOutput = AudioEncoderOutput
|
"AudioEncoderOutput",
|
||||||
StyleModel = StyleModel
|
"StyleModel",
|
||||||
Gligen = Gligen
|
"Gligen",
|
||||||
UpscaleModel = UpscaleModel
|
"UpscaleModel",
|
||||||
Audio = Audio
|
"Audio",
|
||||||
Video = Video
|
"Video",
|
||||||
SVG = SVG
|
"SVG",
|
||||||
LoraModel = LoraModel
|
"LoraModel",
|
||||||
LossMap = LossMap
|
"LossMap",
|
||||||
Voxel = Voxel
|
"Voxel",
|
||||||
Mesh = Mesh
|
"Mesh",
|
||||||
Hooks = Hooks
|
"Hooks",
|
||||||
HookKeyframes = HookKeyframes
|
"HookKeyframes",
|
||||||
TimestepsRange = TimestepsRange
|
"TimestepsRange",
|
||||||
LatentOperation = LatentOperation
|
"LatentOperation",
|
||||||
FlowControl = FlowControl
|
"FlowControl",
|
||||||
Accumulation = Accumulation
|
"Accumulation",
|
||||||
Load3DCamera = Load3DCamera
|
"Load3DCamera",
|
||||||
Load3D = Load3D
|
"Load3D",
|
||||||
Load3DAnimation = Load3DAnimation
|
"Load3DAnimation",
|
||||||
Photomaker = Photomaker
|
"Photomaker",
|
||||||
Point = Point
|
"Point",
|
||||||
FaceAnalysis = FaceAnalysis
|
"FaceAnalysis",
|
||||||
BBOX = BBOX
|
"BBOX",
|
||||||
SEGS = SEGS
|
"SEGS",
|
||||||
AnyType = AnyType
|
"AnyType",
|
||||||
MultiType = MultiType
|
"MultiType",
|
||||||
#---------------------------------
|
# Other classes
|
||||||
HiddenHolder = HiddenHolder
|
"HiddenHolder",
|
||||||
Hidden = Hidden
|
"Hidden",
|
||||||
NodeInfoV1 = NodeInfoV1
|
"NodeInfoV1",
|
||||||
NodeInfoV3 = NodeInfoV3
|
"NodeInfoV3",
|
||||||
Schema = Schema
|
"Schema",
|
||||||
ComfyNode = ComfyNode
|
"ComfyNode",
|
||||||
NodeOutput = NodeOutput
|
"NodeOutput",
|
||||||
add_to_dict_v1 = staticmethod(add_to_dict_v1)
|
"add_to_dict_v1",
|
||||||
add_to_dict_v3 = staticmethod(add_to_dict_v3)
|
"add_to_dict_v3",
|
||||||
|
]
|
||||||
|
|||||||
@ -449,15 +449,16 @@ class PreviewText(_UIOutput):
|
|||||||
return {"text": (self.value,)}
|
return {"text": (self.value,)}
|
||||||
|
|
||||||
|
|
||||||
class _UI:
|
__all__ = [
|
||||||
SavedResult = SavedResult
|
"SavedResult",
|
||||||
SavedImages = SavedImages
|
"SavedImages",
|
||||||
SavedAudios = SavedAudios
|
"SavedAudios",
|
||||||
ImageSaveHelper = ImageSaveHelper
|
"ImageSaveHelper",
|
||||||
AudioSaveHelper = AudioSaveHelper
|
"AudioSaveHelper",
|
||||||
PreviewImage = PreviewImage
|
"PreviewImage",
|
||||||
PreviewMask = PreviewMask
|
"PreviewMask",
|
||||||
PreviewAudio = PreviewAudio
|
"PreviewAudio",
|
||||||
PreviewVideo = PreviewVideo
|
"PreviewVideo",
|
||||||
PreviewUI3D = PreviewUI3D
|
"PreviewUI3D",
|
||||||
PreviewText = PreviewText
|
"PreviewText",
|
||||||
|
]
|
||||||
|
|||||||
@ -1,704 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import aiohttp
|
|
||||||
import io
|
|
||||||
import logging
|
|
||||||
import mimetypes
|
|
||||||
from typing import Optional, Union
|
|
||||||
from comfy.utils import common_upscale
|
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.util import VideoContainer, VideoCodec
|
|
||||||
from comfy_api.input.video_types import VideoInput
|
|
||||||
from comfy_api.input.basic_types import AudioInput
|
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiClient,
|
|
||||||
ApiEndpoint,
|
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
UploadRequest,
|
|
||||||
UploadResponse,
|
|
||||||
)
|
|
||||||
from server import PromptServer
|
|
||||||
from comfy.cli_args import args
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
import base64
|
|
||||||
import uuid
|
|
||||||
from io import BytesIO
|
|
||||||
import av
|
|
||||||
|
|
||||||
|
|
||||||
async def download_url_to_video_output(
|
|
||||||
video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
|
||||||
) -> VideoFromFile:
|
|
||||||
"""Downloads a video from a URL and returns a `VIDEO` output.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video_url: The URL of the video to download.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A Comfy node `VIDEO` output.
|
|
||||||
"""
|
|
||||||
video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs)
|
|
||||||
if video_io is None:
|
|
||||||
error_msg = f"Failed to download video from {video_url}"
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise ValueError(error_msg)
|
|
||||||
return VideoFromFile(video_io)
|
|
||||||
|
|
||||||
|
|
||||||
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
|
||||||
"""Downscale input image tensor to roughly the specified total pixels."""
|
|
||||||
samples = image.movedim(-1, 1)
|
|
||||||
total = int(total_pixels)
|
|
||||||
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
|
||||||
if scale_by >= 1:
|
|
||||||
return image
|
|
||||||
width = round(samples.shape[3] * scale_by)
|
|
||||||
height = round(samples.shape[2] * scale_by)
|
|
||||||
|
|
||||||
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
|
||||||
s = s.movedim(1, -1)
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_and_cast_response(
|
|
||||||
response, timeout: int = None, node_id: Union[str, None] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Validates and casts a response to a torch.Tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The response to validate and cast.
|
|
||||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A torch.Tensor representing the image (1, H, W, C).
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the response is not valid.
|
|
||||||
"""
|
|
||||||
# validate raw JSON response
|
|
||||||
data = response.data
|
|
||||||
if not data or len(data) == 0:
|
|
||||||
raise ValueError("No images returned from API endpoint")
|
|
||||||
|
|
||||||
# Initialize list to store image tensors
|
|
||||||
image_tensors: list[torch.Tensor] = []
|
|
||||||
|
|
||||||
# Process each image in the data array
|
|
||||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
|
||||||
for img_data in data:
|
|
||||||
img_bytes: bytes
|
|
||||||
if img_data.b64_json:
|
|
||||||
img_bytes = base64.b64decode(img_data.b64_json)
|
|
||||||
elif img_data.url:
|
|
||||||
if node_id:
|
|
||||||
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
|
|
||||||
async with session.get(img_data.url) as resp:
|
|
||||||
if resp.status != 200:
|
|
||||||
raise ValueError("Failed to download generated image")
|
|
||||||
img_bytes = await resp.read()
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid image payload – neither URL nor base64 data present.")
|
|
||||||
|
|
||||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
|
|
||||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
|
||||||
image_tensors.append(torch.from_numpy(arr))
|
|
||||||
|
|
||||||
return torch.stack(image_tensors, dim=0)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_aspect_ratio(
|
|
||||||
aspect_ratio: str,
|
|
||||||
minimum_ratio: float,
|
|
||||||
maximum_ratio: float,
|
|
||||||
minimum_ratio_str: str,
|
|
||||||
maximum_ratio_str: str,
|
|
||||||
) -> float:
|
|
||||||
"""Validates and casts an aspect ratio string to a float.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
aspect_ratio: The aspect ratio string to validate.
|
|
||||||
minimum_ratio: The minimum aspect ratio.
|
|
||||||
maximum_ratio: The maximum aspect ratio.
|
|
||||||
minimum_ratio_str: The minimum aspect ratio string.
|
|
||||||
maximum_ratio_str: The maximum aspect ratio string.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The validated and cast aspect ratio.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If the aspect ratio is not valid.
|
|
||||||
"""
|
|
||||||
# get ratio values
|
|
||||||
numbers = aspect_ratio.split(":")
|
|
||||||
if len(numbers) != 2:
|
|
||||||
raise TypeError(
|
|
||||||
f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
numerator = int(numbers[0])
|
|
||||||
denominator = int(numbers[1])
|
|
||||||
except ValueError as exc:
|
|
||||||
raise TypeError(
|
|
||||||
f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
|
|
||||||
) from exc
|
|
||||||
calculated_ratio = numerator / denominator
|
|
||||||
# if not close to minimum and maximum, check bounds
|
|
||||||
if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
|
|
||||||
calculated_ratio, maximum_ratio
|
|
||||||
):
|
|
||||||
if calculated_ratio < minimum_ratio:
|
|
||||||
raise TypeError(
|
|
||||||
f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
|
||||||
)
|
|
||||||
if calculated_ratio > maximum_ratio:
|
|
||||||
raise TypeError(
|
|
||||||
f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
|
|
||||||
)
|
|
||||||
return aspect_ratio
|
|
||||||
|
|
||||||
|
|
||||||
def mimetype_to_extension(mime_type: str) -> str:
|
|
||||||
"""Converts a MIME type to a file extension."""
|
|
||||||
return mime_type.split("/")[-1].lower()
|
|
||||||
|
|
||||||
|
|
||||||
async def download_url_to_bytesio(
|
|
||||||
url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None
|
|
||||||
) -> BytesIO:
|
|
||||||
"""Downloads content from a URL using requests and returns it as BytesIO.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: The URL to download.
|
|
||||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
BytesIO object containing the downloaded content.
|
|
||||||
"""
|
|
||||||
headers = {}
|
|
||||||
if url.startswith("/proxy/"):
|
|
||||||
url = str(args.comfy_api_base).rstrip("/") + url
|
|
||||||
auth_token = auth_kwargs.get("auth_token")
|
|
||||||
comfy_api_key = auth_kwargs.get("comfy_api_key")
|
|
||||||
if auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {auth_token}"
|
|
||||||
elif comfy_api_key:
|
|
||||||
headers["X-API-KEY"] = comfy_api_key
|
|
||||||
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
|
|
||||||
async with session.get(url, headers=headers) as resp:
|
|
||||||
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
|
|
||||||
return BytesIO(await resp.read())
|
|
||||||
|
|
||||||
|
|
||||||
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
|
||||||
"""Converts image data from BytesIO to a torch.Tensor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_bytesio: BytesIO object containing the image data.
|
|
||||||
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A torch.Tensor representing the image (1, H, W, C).
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
|
||||||
ValueError: If the specified mode is invalid.
|
|
||||||
"""
|
|
||||||
image = Image.open(image_bytesio)
|
|
||||||
image = image.convert(mode)
|
|
||||||
image_array = np.array(image).astype(np.float32) / 255.0
|
|
||||||
return torch.from_numpy(image_array).unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
|
|
||||||
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
|
||||||
image_bytesio = await download_url_to_bytesio(url, timeout)
|
|
||||||
return bytesio_to_image_tensor(image_bytesio)
|
|
||||||
|
|
||||||
|
|
||||||
def process_image_response(response_content: bytes | str) -> torch.Tensor:
|
|
||||||
"""Uses content from a Response object and converts it to a torch.Tensor"""
|
|
||||||
return bytesio_to_image_tensor(BytesIO(response_content))
|
|
||||||
|
|
||||||
|
|
||||||
def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
|
||||||
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
|
||||||
if len(image.shape) > 3:
|
|
||||||
image = image[0]
|
|
||||||
# TODO: remove alpha if not allowed and present
|
|
||||||
input_tensor = image.cpu()
|
|
||||||
input_tensor = downscale_image_tensor(
|
|
||||||
input_tensor.unsqueeze(0), total_pixels=total_pixels
|
|
||||||
).squeeze()
|
|
||||||
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
|
||||||
img = Image.fromarray(image_np)
|
|
||||||
return img
|
|
||||||
|
|
||||||
|
|
||||||
def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
|
||||||
"""Converts a PIL Image to a BytesIO object."""
|
|
||||||
if not mime_type:
|
|
||||||
mime_type = "image/png"
|
|
||||||
|
|
||||||
img_byte_arr = io.BytesIO()
|
|
||||||
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
|
||||||
pil_format = mime_type.split("/")[-1].upper()
|
|
||||||
if pil_format == "JPG":
|
|
||||||
pil_format = "JPEG"
|
|
||||||
img.save(img_byte_arr, format=pil_format)
|
|
||||||
img_byte_arr.seek(0)
|
|
||||||
return img_byte_arr
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_bytesio(
|
|
||||||
image: torch.Tensor,
|
|
||||||
name: Optional[str] = None,
|
|
||||||
total_pixels: int = 2048 * 2048,
|
|
||||||
mime_type: str = "image/png",
|
|
||||||
) -> BytesIO:
|
|
||||||
"""Converts a torch.Tensor image to a named BytesIO object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: Input torch.Tensor image.
|
|
||||||
name: Optional filename for the BytesIO object.
|
|
||||||
total_pixels: Maximum total pixels for potential downscaling.
|
|
||||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Named BytesIO object containing the image data.
|
|
||||||
"""
|
|
||||||
if not mime_type:
|
|
||||||
mime_type = "image/png"
|
|
||||||
|
|
||||||
pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
|
|
||||||
img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
|
||||||
img_binary.name = (
|
|
||||||
f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
|
||||||
)
|
|
||||||
return img_binary
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_base64_string(
|
|
||||||
image_tensor: torch.Tensor,
|
|
||||||
total_pixels: int = 2048 * 2048,
|
|
||||||
mime_type: str = "image/png",
|
|
||||||
) -> str:
|
|
||||||
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_tensor: Input torch.Tensor image.
|
|
||||||
total_pixels: Maximum total pixels for potential downscaling.
|
|
||||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Base64 encoded string of the image.
|
|
||||||
"""
|
|
||||||
pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
|
||||||
img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
|
|
||||||
img_bytes = img_byte_arr.getvalue()
|
|
||||||
# Encode bytes to base64 string
|
|
||||||
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
|
||||||
return base64_encoded_string
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_to_data_uri(
|
|
||||||
image_tensor: torch.Tensor,
|
|
||||||
total_pixels: int = 2048 * 2048,
|
|
||||||
mime_type: str = "image/png",
|
|
||||||
) -> str:
|
|
||||||
"""Converts a tensor image to a Data URI string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_tensor: Input torch.Tensor image.
|
|
||||||
total_pixels: Maximum total pixels for potential downscaling.
|
|
||||||
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Data URI string (e.g., 'data:image/png;base64,...').
|
|
||||||
"""
|
|
||||||
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
|
||||||
return f"data:{mime_type};base64,{base64_string}"
|
|
||||||
|
|
||||||
|
|
||||||
def text_filepath_to_base64_string(filepath: str) -> str:
|
|
||||||
"""Converts a text file to a base64 string."""
|
|
||||||
with open(filepath, "rb") as f:
|
|
||||||
file_content = f.read()
|
|
||||||
return base64.b64encode(file_content).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def text_filepath_to_data_uri(filepath: str) -> str:
|
|
||||||
"""Converts a text file to a data URI."""
|
|
||||||
base64_string = text_filepath_to_base64_string(filepath)
|
|
||||||
mime_type, _ = mimetypes.guess_type(filepath)
|
|
||||||
if mime_type is None:
|
|
||||||
mime_type = "application/octet-stream"
|
|
||||||
return f"data:{mime_type};base64,{base64_string}"
|
|
||||||
|
|
||||||
|
|
||||||
async def upload_file_to_comfyapi(
|
|
||||||
file_bytes_io: BytesIO,
|
|
||||||
filename: str,
|
|
||||||
upload_mime_type: Optional[str],
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Uploads a single file to ComfyUI API and returns its download URL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_bytes_io: BytesIO object containing the file data.
|
|
||||||
filename: The filename of the file.
|
|
||||||
upload_mime_type: MIME type of the file.
|
|
||||||
auth_kwargs: Optional authentication token(s).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The download URL for the uploaded file.
|
|
||||||
"""
|
|
||||||
if upload_mime_type is None:
|
|
||||||
request_object = UploadRequest(file_name=filename)
|
|
||||||
else:
|
|
||||||
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/customers/storage",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=UploadRequest,
|
|
||||||
response_model=UploadResponse,
|
|
||||||
),
|
|
||||||
request=request_object,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
response: UploadResponse = await operation.execute()
|
|
||||||
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
|
|
||||||
return response.download_url
|
|
||||||
|
|
||||||
|
|
||||||
def video_to_base64_string(
|
|
||||||
video: VideoInput,
|
|
||||||
container_format: VideoContainer = None,
|
|
||||||
codec: VideoCodec = None
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Converts a video input to a base64 string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video: The video input to convert
|
|
||||||
container_format: Optional container format to use (defaults to video.container if available)
|
|
||||||
codec: Optional codec to use (defaults to video.codec if available)
|
|
||||||
"""
|
|
||||||
video_bytes_io = io.BytesIO()
|
|
||||||
|
|
||||||
# Use provided format/codec if specified, otherwise use video's own if available
|
|
||||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
|
||||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
|
||||||
|
|
||||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
|
||||||
video_bytes_io.seek(0)
|
|
||||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
async def upload_video_to_comfyapi(
|
|
||||||
video: VideoInput,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
container: VideoContainer = VideoContainer.MP4,
|
|
||||||
codec: VideoCodec = VideoCodec.H264,
|
|
||||||
max_duration: Optional[int] = None,
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Uploads a single video to ComfyUI API and returns its download URL.
|
|
||||||
Uses the specified container and codec for saving the video before upload.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video: VideoInput object (Comfy VIDEO type).
|
|
||||||
auth_kwargs: Optional authentication token(s).
|
|
||||||
container: The video container format to use (default: MP4).
|
|
||||||
codec: The video codec to use (default: H264).
|
|
||||||
max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The download URL for the uploaded video file.
|
|
||||||
"""
|
|
||||||
if max_duration is not None:
|
|
||||||
try:
|
|
||||||
actual_duration = video.duration_seconds
|
|
||||||
if actual_duration is not None and actual_duration > max_duration:
|
|
||||||
raise ValueError(
|
|
||||||
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Error getting video duration: {e}")
|
|
||||||
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
|
||||||
|
|
||||||
upload_mime_type = f"video/{container.value.lower()}"
|
|
||||||
filename = f"uploaded_video.{container.value.lower()}"
|
|
||||||
|
|
||||||
# Convert VideoInput to BytesIO using specified container/codec
|
|
||||||
video_bytes_io = io.BytesIO()
|
|
||||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
|
||||||
video_bytes_io.seek(0)
|
|
||||||
|
|
||||||
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Contiguous numpy array of the audio waveform. If the audio was batched,
|
|
||||||
the first item is taken.
|
|
||||||
"""
|
|
||||||
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
|
||||||
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
|
||||||
|
|
||||||
# If batch is > 1, take first item
|
|
||||||
if waveform.shape[0] > 1:
|
|
||||||
waveform = waveform[0]
|
|
||||||
|
|
||||||
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
|
||||||
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
|
||||||
if audio_data_np.dtype != np.float32:
|
|
||||||
audio_data_np = audio_data_np.astype(np.float32)
|
|
||||||
|
|
||||||
return audio_data_np
|
|
||||||
|
|
||||||
|
|
||||||
def audio_ndarray_to_bytesio(
|
|
||||||
audio_data_np: np.ndarray,
|
|
||||||
sample_rate: int,
|
|
||||||
container_format: str = "mp4",
|
|
||||||
codec_name: str = "aac",
|
|
||||||
) -> BytesIO:
|
|
||||||
"""
|
|
||||||
Encodes a numpy array of audio data into a BytesIO object.
|
|
||||||
"""
|
|
||||||
audio_bytes_io = io.BytesIO()
|
|
||||||
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
|
||||||
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
|
||||||
frame = av.AudioFrame.from_ndarray(
|
|
||||||
audio_data_np,
|
|
||||||
format="fltp",
|
|
||||||
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
|
||||||
)
|
|
||||||
frame.sample_rate = sample_rate
|
|
||||||
frame.pts = 0
|
|
||||||
|
|
||||||
for packet in audio_stream.encode(frame):
|
|
||||||
output_container.mux(packet)
|
|
||||||
|
|
||||||
# Flush stream
|
|
||||||
for packet in audio_stream.encode(None):
|
|
||||||
output_container.mux(packet)
|
|
||||||
|
|
||||||
audio_bytes_io.seek(0)
|
|
||||||
return audio_bytes_io
|
|
||||||
|
|
||||||
|
|
||||||
async def upload_audio_to_comfyapi(
|
|
||||||
audio: AudioInput,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
container_format: str = "mp4",
|
|
||||||
codec_name: str = "aac",
|
|
||||||
mime_type: str = "audio/mp4",
|
|
||||||
filename: str = "uploaded_audio.mp4",
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Uploads a single audio input to ComfyUI API and returns its download URL.
|
|
||||||
Encodes the raw waveform into the specified format before uploading.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
|
|
||||||
auth_kwargs: Optional authentication token(s).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The download URL for the uploaded audio file.
|
|
||||||
"""
|
|
||||||
sample_rate: int = audio["sample_rate"]
|
|
||||||
waveform: torch.Tensor = audio["waveform"]
|
|
||||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
|
||||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
|
||||||
audio_data_np, sample_rate, container_format, codec_name
|
|
||||||
)
|
|
||||||
|
|
||||||
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
|
||||||
if wav.dtype.is_floating_point:
|
|
||||||
return wav
|
|
||||||
elif wav.dtype == torch.int16:
|
|
||||||
return wav.float() / (2 ** 15)
|
|
||||||
elif wav.dtype == torch.int32:
|
|
||||||
return wav.float() / (2 ** 31)
|
|
||||||
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
|
||||||
|
|
||||||
|
|
||||||
def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict:
|
|
||||||
"""
|
|
||||||
Decode any common audio container from bytes using PyAV and return
|
|
||||||
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
|
||||||
"""
|
|
||||||
with av.open(io.BytesIO(audio_bytes)) as af:
|
|
||||||
if not af.streams.audio:
|
|
||||||
raise ValueError("No audio stream found in response.")
|
|
||||||
stream = af.streams.audio[0]
|
|
||||||
|
|
||||||
in_sr = int(stream.codec_context.sample_rate)
|
|
||||||
out_sr = in_sr
|
|
||||||
|
|
||||||
frames: list[torch.Tensor] = []
|
|
||||||
n_channels = stream.channels or 1
|
|
||||||
|
|
||||||
for frame in af.decode(streams=stream.index):
|
|
||||||
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
|
||||||
buf = torch.from_numpy(arr)
|
|
||||||
if buf.ndim == 1:
|
|
||||||
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
|
||||||
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
|
||||||
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
|
||||||
elif buf.shape[0] != n_channels:
|
|
||||||
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
|
||||||
frames.append(buf)
|
|
||||||
|
|
||||||
if not frames:
|
|
||||||
raise ValueError("Decoded zero audio frames.")
|
|
||||||
|
|
||||||
wav = torch.cat(frames, dim=1) # [C, T]
|
|
||||||
wav = f32_pcm(wav)
|
|
||||||
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
|
||||||
|
|
||||||
|
|
||||||
def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO:
|
|
||||||
waveform = audio["waveform"].cpu()
|
|
||||||
|
|
||||||
output_buffer = io.BytesIO()
|
|
||||||
output_container = av.open(output_buffer, mode='w', format="mp3")
|
|
||||||
|
|
||||||
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
|
||||||
out_stream.bit_rate = 320000
|
|
||||||
|
|
||||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
|
||||||
frame.sample_rate = audio["sample_rate"]
|
|
||||||
frame.pts = 0
|
|
||||||
output_container.mux(out_stream.encode(frame))
|
|
||||||
output_container.mux(out_stream.encode(None))
|
|
||||||
output_container.close()
|
|
||||||
output_buffer.seek(0)
|
|
||||||
return output_buffer
|
|
||||||
|
|
||||||
|
|
||||||
def audio_to_base64_string(
|
|
||||||
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
|
|
||||||
) -> str:
|
|
||||||
"""Converts an audio input to a base64 string."""
|
|
||||||
sample_rate: int = audio["sample_rate"]
|
|
||||||
waveform: torch.Tensor = audio["waveform"]
|
|
||||||
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
|
||||||
audio_bytes_io = audio_ndarray_to_bytesio(
|
|
||||||
audio_data_np, sample_rate, container_format, codec_name
|
|
||||||
)
|
|
||||||
audio_bytes = audio_bytes_io.getvalue()
|
|
||||||
return base64.b64encode(audio_bytes).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
async def upload_images_to_comfyapi(
|
|
||||||
image: torch.Tensor,
|
|
||||||
max_images=8,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
mime_type: Optional[str] = None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
Uploads images to ComfyUI API and returns download URLs.
|
|
||||||
To upload multiple images, stack them in the batch dimension first.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: Input torch.Tensor image.
|
|
||||||
max_images: Maximum number of images to upload.
|
|
||||||
auth_kwargs: Optional authentication token(s).
|
|
||||||
mime_type: Optional MIME type for the image.
|
|
||||||
"""
|
|
||||||
# if batch, try to upload each file if max_images is greater than 0
|
|
||||||
download_urls: list[str] = []
|
|
||||||
is_batch = len(image.shape) > 3
|
|
||||||
batch_len = image.shape[0] if is_batch else 1
|
|
||||||
|
|
||||||
for idx in range(min(batch_len, max_images)):
|
|
||||||
tensor = image[idx] if is_batch else image
|
|
||||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
|
||||||
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
|
|
||||||
download_urls.append(url)
|
|
||||||
return download_urls
|
|
||||||
|
|
||||||
|
|
||||||
def resize_mask_to_image(
|
|
||||||
mask: torch.Tensor,
|
|
||||||
image: torch.Tensor,
|
|
||||||
upscale_method="nearest-exact",
|
|
||||||
crop="disabled",
|
|
||||||
allow_gradient=True,
|
|
||||||
add_channel_dim=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
|
|
||||||
"""
|
|
||||||
_, H, W, _ = image.shape
|
|
||||||
mask = mask.unsqueeze(-1)
|
|
||||||
mask = mask.movedim(-1, 1)
|
|
||||||
mask = common_upscale(
|
|
||||||
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
|
|
||||||
)
|
|
||||||
mask = mask.movedim(1, -1)
|
|
||||||
if not add_channel_dim:
|
|
||||||
mask = mask.squeeze(-1)
|
|
||||||
if not allow_gradient:
|
|
||||||
mask = (mask > 0.5).float()
|
|
||||||
return mask
|
|
||||||
|
|
||||||
|
|
||||||
def validate_string(
|
|
||||||
string: str,
|
|
||||||
strip_whitespace=True,
|
|
||||||
field_name="prompt",
|
|
||||||
min_length=None,
|
|
||||||
max_length=None,
|
|
||||||
):
|
|
||||||
if string is None:
|
|
||||||
raise Exception(f"Field '{field_name}' cannot be empty.")
|
|
||||||
if strip_whitespace:
|
|
||||||
string = string.strip()
|
|
||||||
if min_length and len(string) < min_length:
|
|
||||||
raise Exception(
|
|
||||||
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
|
||||||
)
|
|
||||||
if max_length and len(string) > max_length:
|
|
||||||
raise Exception(
|
|
||||||
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def image_tensor_pair_to_batch(
|
|
||||||
image1: torch.Tensor, image2: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Converts a pair of image tensors to a batch tensor.
|
|
||||||
If the images are not the same size, the smaller image is resized to
|
|
||||||
match the larger image.
|
|
||||||
"""
|
|
||||||
if image1.shape[1:] != image2.shape[1:]:
|
|
||||||
image2 = common_upscale(
|
|
||||||
image2.movedim(-1, 1),
|
|
||||||
image1.shape[2],
|
|
||||||
image1.shape[1],
|
|
||||||
"bilinear",
|
|
||||||
"center",
|
|
||||||
).movedim(1, -1)
|
|
||||||
return torch.cat((image1, image2), dim=0)
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: filtered-openapi.yaml
|
|
||||||
# timestamp: 2025-04-29T23:44:54+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from . import PixverseDto
|
|
||||||
|
|
||||||
|
|
||||||
class ResponseData(BaseModel):
|
|
||||||
ErrCode: Optional[int] = None
|
|
||||||
ErrMsg: Optional[str] = None
|
|
||||||
Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
|
|
||||||
@ -1,57 +0,0 @@
|
|||||||
# generated by datamodel-codegen:
|
|
||||||
# filename: filtered-openapi.yaml
|
|
||||||
# timestamp: 2025-04-29T23:44:54+00:00
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
class V2OpenAPII2VResp(BaseModel):
|
|
||||||
video_id: Optional[int] = Field(None, description='Video_id')
|
|
||||||
|
|
||||||
|
|
||||||
class V2OpenAPIT2VReq(BaseModel):
|
|
||||||
aspect_ratio: str = Field(
|
|
||||||
..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
|
|
||||||
)
|
|
||||||
duration: int = Field(
|
|
||||||
...,
|
|
||||||
description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
|
|
||||||
examples=[5],
|
|
||||||
)
|
|
||||||
model: str = Field(
|
|
||||||
..., description='Model version (only supports v3.5)', examples=['v3.5']
|
|
||||||
)
|
|
||||||
motion_mode: Optional[str] = Field(
|
|
||||||
'normal',
|
|
||||||
description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
|
|
||||||
examples=['normal'],
|
|
||||||
)
|
|
||||||
negative_prompt: Optional[str] = Field(
|
|
||||||
None, description='Negative prompt\n', max_length=2048
|
|
||||||
)
|
|
||||||
prompt: str = Field(..., description='Prompt', max_length=2048)
|
|
||||||
quality: str = Field(
|
|
||||||
...,
|
|
||||||
description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
|
|
||||||
examples=['540p'],
|
|
||||||
)
|
|
||||||
seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
|
|
||||||
style: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
|
|
||||||
examples=['anime'],
|
|
||||||
)
|
|
||||||
template_id: Optional[int] = Field(
|
|
||||||
None,
|
|
||||||
description='Template ID (template_id must be activated before use)',
|
|
||||||
examples=[302325299692608],
|
|
||||||
)
|
|
||||||
water_mark: Optional[bool] = Field(
|
|
||||||
False,
|
|
||||||
description='Watermark (true: add watermark, false: no watermark)',
|
|
||||||
examples=[False],
|
|
||||||
)
|
|
||||||
@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel):
|
|||||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||||
|
|
||||||
|
|
||||||
class BFLFluxCannyImageRequest(BaseModel):
|
|
||||||
prompt: str = Field(..., description='Text prompt for image generation')
|
|
||||||
prompt_upsampling: Optional[bool] = Field(
|
|
||||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
|
||||||
)
|
|
||||||
canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
|
|
||||||
canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
|
|
||||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
|
||||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
|
||||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
|
||||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
|
||||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
|
||||||
)
|
|
||||||
output_format: Optional[BFLOutputFormat] = Field(
|
|
||||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
|
||||||
)
|
|
||||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
|
||||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
|
||||||
|
|
||||||
|
|
||||||
class BFLFluxDepthImageRequest(BaseModel):
|
|
||||||
prompt: str = Field(..., description='Text prompt for image generation')
|
|
||||||
prompt_upsampling: Optional[bool] = Field(
|
|
||||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
|
||||||
)
|
|
||||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
|
||||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
|
||||||
guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
|
|
||||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
|
||||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
|
||||||
)
|
|
||||||
output_format: Optional[BFLOutputFormat] = Field(
|
|
||||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
|
||||||
)
|
|
||||||
control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
|
|
||||||
preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
|
|
||||||
|
|
||||||
|
|
||||||
class BFLFluxProGenerateRequest(BaseModel):
|
class BFLFluxProGenerateRequest(BaseModel):
|
||||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||||
prompt_upsampling: Optional[bool] = Field(
|
prompt_upsampling: Optional[bool] = Field(
|
||||||
@ -160,15 +122,8 @@ class BFLStatus(str, Enum):
|
|||||||
error = "Error"
|
error = "Error"
|
||||||
|
|
||||||
|
|
||||||
class BFLFluxProStatusResponse(BaseModel):
|
class BFLFluxStatusResponse(BaseModel):
|
||||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||||
status: BFLStatus = Field(..., description="The status of the task.")
|
status: BFLStatus = Field(..., description="The status of the task.")
|
||||||
result: Optional[Dict[str, Any]] = Field(
|
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||||
None, description="The result of the task (null if not completed)."
|
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||||
)
|
|
||||||
progress: confloat(ge=0.0, le=1.0) = Field(
|
|
||||||
..., description="The progress of the task (0.0 to 1.0)."
|
|
||||||
)
|
|
||||||
details: Optional[Dict[str, Any]] = Field(
|
|
||||||
None, description="Additional details about the task (null if not available)."
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,963 +0,0 @@
|
|||||||
"""
|
|
||||||
API Client Framework for api.comfy.org.
|
|
||||||
|
|
||||||
This module provides a flexible framework for making API requests from ComfyUI nodes.
|
|
||||||
It supports both synchronous and asynchronous API operations with proper type validation.
|
|
||||||
|
|
||||||
Key Components:
|
|
||||||
--------------
|
|
||||||
1. ApiClient - Handles HTTP requests with authentication and error handling
|
|
||||||
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
|
|
||||||
3. ApiOperation - Executes a single synchronous API operation
|
|
||||||
|
|
||||||
Usage Examples:
|
|
||||||
--------------
|
|
||||||
|
|
||||||
# Example 1: Synchronous API Operation
|
|
||||||
# ------------------------------------
|
|
||||||
# For a simple API call that returns the result immediately:
|
|
||||||
|
|
||||||
# 1. Create the API client
|
|
||||||
api_client = ApiClient(
|
|
||||||
base_url="https://api.example.com",
|
|
||||||
auth_token="your_auth_token_here",
|
|
||||||
comfy_api_key="your_comfy_api_key_here",
|
|
||||||
timeout=30.0,
|
|
||||||
verify_ssl=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Define the endpoint
|
|
||||||
user_info_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/users/me",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest, # No request body needed
|
|
||||||
response_model=UserProfile, # Pydantic model for the response
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Create the request object
|
|
||||||
request = EmptyRequest()
|
|
||||||
|
|
||||||
# 4. Create and execute the operation
|
|
||||||
operation = ApiOperation(
|
|
||||||
endpoint=user_info_endpoint,
|
|
||||||
request=request
|
|
||||||
)
|
|
||||||
user_profile = await operation.execute(client=api_client) # Returns immediately with the result
|
|
||||||
|
|
||||||
|
|
||||||
# Example 2: Asynchronous API Operation with Polling
|
|
||||||
# -------------------------------------------------
|
|
||||||
# For an API that starts a task and requires polling for completion:
|
|
||||||
|
|
||||||
# 1. Define the endpoints (initial request and polling)
|
|
||||||
generate_image_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/images/generate",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=ImageGenerationRequest,
|
|
||||||
response_model=TaskCreatedResponse,
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
check_task_endpoint = ApiEndpoint(
|
|
||||||
path="/v1/tasks/{task_id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=ImageGenerationResult,
|
|
||||||
query_params=None
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Create the request object
|
|
||||||
request = ImageGenerationRequest(
|
|
||||||
prompt="a beautiful sunset over mountains",
|
|
||||||
width=1024,
|
|
||||||
height=1024,
|
|
||||||
num_images=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Create and execute the polling operation
|
|
||||||
operation = PollingOperation(
|
|
||||||
initial_endpoint=generate_image_endpoint,
|
|
||||||
initial_request=request,
|
|
||||||
poll_endpoint=check_task_endpoint,
|
|
||||||
task_id_field="task_id",
|
|
||||||
status_field="status",
|
|
||||||
completed_statuses=["completed"],
|
|
||||||
failed_statuses=["failed", "error"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# This will make the initial request and then poll until completion
|
|
||||||
result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import socket
|
|
||||||
from aiohttp.client_exceptions import ClientError, ClientResponseError
|
|
||||||
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
|
|
||||||
from enum import Enum
|
|
||||||
import json
|
|
||||||
from urllib.parse import urljoin, urlparse
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
import uuid # For generating unique operation IDs
|
|
||||||
|
|
||||||
from server import PromptServer
|
|
||||||
from comfy.cli_args import args
|
|
||||||
from comfy import utils
|
|
||||||
from . import request_logger
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
|
||||||
R = TypeVar("R", bound=BaseModel)
|
|
||||||
P = TypeVar("P", bound=BaseModel) # For poll response
|
|
||||||
|
|
||||||
PROGRESS_BAR_MAX = 100
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkError(Exception):
|
|
||||||
"""Base exception for network-related errors with diagnostic information."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LocalNetworkError(NetworkError):
|
|
||||||
"""Exception raised when local network connectivity issues are detected."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ApiServerError(NetworkError):
|
|
||||||
"""Exception raised when the API server is unreachable but internet is working."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class EmptyRequest(BaseModel):
|
|
||||||
"""Base class for empty request bodies.
|
|
||||||
For GET requests, fields will be sent as query parameters."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UploadRequest(BaseModel):
|
|
||||||
file_name: str = Field(..., description="Filename to upload")
|
|
||||||
content_type: Optional[str] = Field(
|
|
||||||
None,
|
|
||||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UploadResponse(BaseModel):
|
|
||||||
download_url: str = Field(..., description="URL to GET uploaded file")
|
|
||||||
upload_url: str = Field(..., description="URL to PUT file to upload")
|
|
||||||
|
|
||||||
|
|
||||||
class HttpMethod(str, Enum):
|
|
||||||
GET = "GET"
|
|
||||||
POST = "POST"
|
|
||||||
PUT = "PUT"
|
|
||||||
DELETE = "DELETE"
|
|
||||||
PATCH = "PATCH"
|
|
||||||
|
|
||||||
|
|
||||||
class ApiClient:
|
|
||||||
"""
|
|
||||||
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
base_url: str,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
comfy_api_key: Optional[str] = None,
|
|
||||||
timeout: float = 3600.0,
|
|
||||||
verify_ssl: bool = True,
|
|
||||||
max_retries: int = 3,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
retry_status_codes: Optional[Tuple[int, ...]] = None,
|
|
||||||
session: Optional[aiohttp.ClientSession] = None,
|
|
||||||
):
|
|
||||||
self.base_url = base_url
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.comfy_api_key = comfy_api_key
|
|
||||||
self.timeout = timeout
|
|
||||||
self.verify_ssl = verify_ssl
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.retry_backoff_factor = retry_backoff_factor
|
|
||||||
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
|
|
||||||
# 500, 502, 503, 504 (Server Errors)
|
|
||||||
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
|
|
||||||
self._session: Optional[aiohttp.ClientSession] = session
|
|
||||||
self._owns_session = session is None # Track if we have to close it
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _generate_operation_id(path: str) -> str:
|
|
||||||
"""Generates a unique operation ID for logging."""
|
|
||||||
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_json_payload_args(
|
|
||||||
data: Optional[Dict[str, Any]] = None,
|
|
||||||
headers: Optional[Dict[str, str]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"json": data,
|
|
||||||
"headers": headers,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _create_form_data_args(
|
|
||||||
self,
|
|
||||||
data: Dict[str, Any] | None,
|
|
||||||
files: Dict[str, Any] | None,
|
|
||||||
headers: Optional[Dict[str, str]] = None,
|
|
||||||
multipart_parser: Callable | None = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
if headers and "Content-Type" in headers:
|
|
||||||
del headers["Content-Type"]
|
|
||||||
|
|
||||||
if multipart_parser and data:
|
|
||||||
data = multipart_parser(data)
|
|
||||||
|
|
||||||
if isinstance(data, aiohttp.FormData):
|
|
||||||
form = data # If the parser already returned a FormData, pass it through
|
|
||||||
else:
|
|
||||||
form = aiohttp.FormData(default_to_multipart=True)
|
|
||||||
if data: # regular text fields
|
|
||||||
for k, v in data.items():
|
|
||||||
if v is None:
|
|
||||||
continue # aiohttp fails to serialize "None" values
|
|
||||||
# aiohttp expects strings or bytes; convert enums etc.
|
|
||||||
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
|
||||||
|
|
||||||
if files:
|
|
||||||
file_iter = files if isinstance(files, list) else files.items()
|
|
||||||
for field_name, file_obj in file_iter:
|
|
||||||
if file_obj is None:
|
|
||||||
continue # aiohttp fails to serialize "None" values
|
|
||||||
# file_obj can be (filename, bytes/io.BytesIO, content_type) tuple
|
|
||||||
if isinstance(file_obj, tuple):
|
|
||||||
filename, file_value, content_type = self._unpack_tuple(file_obj)
|
|
||||||
else:
|
|
||||||
file_value = file_obj
|
|
||||||
filename = getattr(file_obj, "name", field_name)
|
|
||||||
content_type = "application/octet-stream"
|
|
||||||
|
|
||||||
form.add_field(
|
|
||||||
name=field_name,
|
|
||||||
value=file_value,
|
|
||||||
filename=filename,
|
|
||||||
content_type=content_type,
|
|
||||||
)
|
|
||||||
return {"data": form, "headers": headers or {}}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_urlencoded_form_data_args(
|
|
||||||
data: Dict[str, Any],
|
|
||||||
headers: Optional[Dict[str, str]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
headers = headers or {}
|
|
||||||
headers["Content-Type"] = "application/x-www-form-urlencoded"
|
|
||||||
return {
|
|
||||||
"data": data,
|
|
||||||
"headers": headers,
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_headers(self) -> Dict[str, str]:
|
|
||||||
"""Get headers for API requests, including authentication if available"""
|
|
||||||
headers = {"Content-Type": "application/json", "Accept": "application/json"}
|
|
||||||
|
|
||||||
if self.auth_token:
|
|
||||||
headers["Authorization"] = f"Bearer {self.auth_token}"
|
|
||||||
elif self.comfy_api_key:
|
|
||||||
headers["X-API-KEY"] = self.comfy_api_key
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
||||||
async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
|
|
||||||
"""
|
|
||||||
Check connectivity to determine if network issues are local or server-related.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_url: URL to check connectivity to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with connectivity status details
|
|
||||||
"""
|
|
||||||
results = {
|
|
||||||
"internet_accessible": False,
|
|
||||||
"api_accessible": False,
|
|
||||||
"is_local_issue": False,
|
|
||||||
"is_api_issue": False,
|
|
||||||
}
|
|
||||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
try:
|
|
||||||
async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
|
|
||||||
results["internet_accessible"] = resp.status < 500
|
|
||||||
except (ClientError, asyncio.TimeoutError, socket.gaierror):
|
|
||||||
results["is_local_issue"] = True
|
|
||||||
return results # cannot reach the internet – early exit
|
|
||||||
|
|
||||||
# Now check API health endpoint
|
|
||||||
parsed = urlparse(target_url)
|
|
||||||
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
|
||||||
try:
|
|
||||||
async with session.get(health_url, ssl=self.verify_ssl) as resp:
|
|
||||||
results["api_accessible"] = resp.status < 500
|
|
||||||
except ClientError:
|
|
||||||
pass # leave as False
|
|
||||||
|
|
||||||
results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def request(
|
|
||||||
self,
|
|
||||||
method: str,
|
|
||||||
path: str,
|
|
||||||
params: Optional[Dict[str, Any]] = None,
|
|
||||||
data: Optional[Dict[str, Any]] = None,
|
|
||||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
|
||||||
headers: Optional[Dict[str, str]] = None,
|
|
||||||
content_type: str = "application/json",
|
|
||||||
multipart_parser: Callable | None = None,
|
|
||||||
retry_count: int = 0, # Used internally for tracking retries
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Make an HTTP request to the API with automatic retries for transient errors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
method: HTTP method (GET, POST, etc.)
|
|
||||||
path: API endpoint path (will be joined with base_url)
|
|
||||||
params: Query parameters
|
|
||||||
data: body data
|
|
||||||
files: Files to upload
|
|
||||||
headers: Additional headers
|
|
||||||
content_type: Content type of the request. Defaults to application/json.
|
|
||||||
retry_count: Internal parameter for tracking retries, do not set manually
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed JSON response
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
LocalNetworkError: If local network connectivity issues are detected
|
|
||||||
ApiServerError: If the API server is unreachable but internet is working
|
|
||||||
Exception: For other request failures
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Build full URL and merge headers
|
|
||||||
relative_path = path.lstrip("/")
|
|
||||||
url = urljoin(self.base_url, relative_path)
|
|
||||||
self._check_auth(self.auth_token, self.comfy_api_key)
|
|
||||||
|
|
||||||
request_headers = self.get_headers()
|
|
||||||
if headers:
|
|
||||||
request_headers.update(headers)
|
|
||||||
if files:
|
|
||||||
request_headers.pop("Content-Type", None)
|
|
||||||
if params:
|
|
||||||
params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values
|
|
||||||
|
|
||||||
logging.debug(f"[DEBUG] Request Headers: {request_headers}")
|
|
||||||
logging.debug(f"[DEBUG] Files: {files}")
|
|
||||||
logging.debug(f"[DEBUG] Params: {params}")
|
|
||||||
logging.debug(f"[DEBUG] Data: {data}")
|
|
||||||
|
|
||||||
if content_type == "application/x-www-form-urlencoded":
|
|
||||||
payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
|
|
||||||
elif content_type == "multipart/form-data":
|
|
||||||
payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
|
|
||||||
else:
|
|
||||||
payload_args = self._create_json_payload_args(data, request_headers)
|
|
||||||
|
|
||||||
operation_id = self._generate_operation_id(path)
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method=method,
|
|
||||||
request_url=url,
|
|
||||||
request_headers=request_headers,
|
|
||||||
request_params=params,
|
|
||||||
request_data=data if content_type == "application/json" else "[form-data or other]",
|
|
||||||
)
|
|
||||||
|
|
||||||
session = await self._get_session()
|
|
||||||
try:
|
|
||||||
async with session.request(
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
params=params,
|
|
||||||
ssl=self.verify_ssl,
|
|
||||||
**payload_args,
|
|
||||||
) as resp:
|
|
||||||
if resp.status >= 400:
|
|
||||||
try:
|
|
||||||
error_data = await resp.json()
|
|
||||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
|
||||||
error_data = await resp.text()
|
|
||||||
|
|
||||||
return await self._handle_http_error(
|
|
||||||
ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
|
|
||||||
operation_id,
|
|
||||||
method,
|
|
||||||
url,
|
|
||||||
params,
|
|
||||||
data,
|
|
||||||
files,
|
|
||||||
headers,
|
|
||||||
content_type,
|
|
||||||
multipart_parser,
|
|
||||||
retry_count=retry_count,
|
|
||||||
response_content=error_data,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Success – parse JSON (safely) and log
|
|
||||||
try:
|
|
||||||
payload = await resp.json()
|
|
||||||
response_content_to_log = payload
|
|
||||||
except (aiohttp.ContentTypeError, json.JSONDecodeError):
|
|
||||||
payload = {}
|
|
||||||
response_content_to_log = await resp.text()
|
|
||||||
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method=method,
|
|
||||||
request_url=url,
|
|
||||||
response_status_code=resp.status,
|
|
||||||
response_headers=dict(resp.headers),
|
|
||||||
response_content=response_content_to_log,
|
|
||||||
)
|
|
||||||
return payload
|
|
||||||
|
|
||||||
except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
|
|
||||||
# Treat as *connection* problem – optionally retry, else escalate
|
|
||||||
if retry_count < self.max_retries:
|
|
||||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
|
||||||
logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
|
|
||||||
self.max_retries, str(e))
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
return await self.request(
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
params=params,
|
|
||||||
data=data,
|
|
||||||
files=files,
|
|
||||||
headers=headers,
|
|
||||||
content_type=content_type,
|
|
||||||
multipart_parser=multipart_parser,
|
|
||||||
retry_count=retry_count + 1,
|
|
||||||
)
|
|
||||||
# One final connectivity check for diagnostics
|
|
||||||
connectivity = await self._check_connectivity(self.base_url)
|
|
||||||
if connectivity["is_local_issue"]:
|
|
||||||
raise LocalNetworkError(
|
|
||||||
"Unable to connect to the API server due to local network issues. "
|
|
||||||
"Please check your internet connection and try again."
|
|
||||||
) from e
|
|
||||||
raise ApiServerError(
|
|
||||||
f"The API server at {self.base_url} is currently unreachable. "
|
|
||||||
f"The service may be experiencing issues. Please try again later."
|
|
||||||
) from e
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _check_auth(auth_token, comfy_api_key):
|
|
||||||
"""Verify that an auth token is present or comfy_api_key is present"""
|
|
||||||
if auth_token is None and comfy_api_key is None:
|
|
||||||
raise Exception("Unauthorized: Please login first to use this node.")
|
|
||||||
return auth_token or comfy_api_key
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def upload_file(
|
|
||||||
upload_url: str,
|
|
||||||
file: io.BytesIO | str,
|
|
||||||
content_type: str | None = None,
|
|
||||||
max_retries: int = 3,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
) -> aiohttp.ClientResponse:
|
|
||||||
"""Upload a file to the API with retry logic.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
upload_url: The URL to upload to
|
|
||||||
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
|
|
||||||
content_type: Optional mime type to set for the upload
|
|
||||||
max_retries: Maximum number of retry attempts
|
|
||||||
retry_delay: Initial delay between retries in seconds
|
|
||||||
retry_backoff_factor: Multiplier for the delay after each retry
|
|
||||||
"""
|
|
||||||
headers: Dict[str, str] = {}
|
|
||||||
skip_auto_headers: set[str] = set()
|
|
||||||
if content_type:
|
|
||||||
headers["Content-Type"] = content_type
|
|
||||||
else:
|
|
||||||
# tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status.
|
|
||||||
skip_auto_headers.add("Content-Type")
|
|
||||||
|
|
||||||
# Extract file bytes
|
|
||||||
if isinstance(file, io.BytesIO):
|
|
||||||
file.seek(0)
|
|
||||||
data = file.read()
|
|
||||||
elif isinstance(file, str):
|
|
||||||
with open(file, "rb") as f:
|
|
||||||
data = f.read()
|
|
||||||
else:
|
|
||||||
raise ValueError("File must be BytesIO or str path")
|
|
||||||
|
|
||||||
parsed = urlparse(upload_url)
|
|
||||||
basename = os.path.basename(parsed.path) or parsed.netloc or "upload"
|
|
||||||
operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}"
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method="PUT",
|
|
||||||
request_url=upload_url,
|
|
||||||
request_headers=headers,
|
|
||||||
request_data=f"[File data {len(data)} bytes]",
|
|
||||||
)
|
|
||||||
|
|
||||||
delay = retry_delay
|
|
||||||
for attempt in range(max_retries + 1):
|
|
||||||
try:
|
|
||||||
timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
|
||||||
async with session.put(
|
|
||||||
upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
|
|
||||||
) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method="PUT",
|
|
||||||
request_url=upload_url,
|
|
||||||
response_status_code=resp.status,
|
|
||||||
response_headers=dict(resp.headers),
|
|
||||||
response_content="File uploaded successfully.",
|
|
||||||
)
|
|
||||||
return resp
|
|
||||||
except (ClientError, asyncio.TimeoutError) as e:
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method="PUT",
|
|
||||||
request_url=upload_url,
|
|
||||||
response_status_code=e.status if hasattr(e, "status") else None,
|
|
||||||
response_headers=dict(e.headers) if hasattr(e, "headers") else None,
|
|
||||||
response_content=None,
|
|
||||||
error_message=f"{type(e).__name__}: {str(e)}",
|
|
||||||
)
|
|
||||||
if attempt < max_retries:
|
|
||||||
logging.warning(
|
|
||||||
"Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
delay *= retry_backoff_factor
|
|
||||||
else:
|
|
||||||
raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e
|
|
||||||
|
|
||||||
async def _handle_http_error(
|
|
||||||
self,
|
|
||||||
exc: ClientResponseError,
|
|
||||||
operation_id: str,
|
|
||||||
*req_meta,
|
|
||||||
retry_count: int,
|
|
||||||
response_content: dict | str = "",
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
status_code = exc.status
|
|
||||||
if status_code == 401:
|
|
||||||
user_friendly = "Unauthorized: Please login first to use this node."
|
|
||||||
elif status_code == 402:
|
|
||||||
user_friendly = "Payment Required: Please add credits to your account to use this node."
|
|
||||||
elif status_code == 409:
|
|
||||||
user_friendly = "There is a problem with your account. Please contact support@comfy.org."
|
|
||||||
elif status_code == 429:
|
|
||||||
user_friendly = "Rate Limit Exceeded: Please try again later."
|
|
||||||
else:
|
|
||||||
if isinstance(response_content, dict):
|
|
||||||
if "error" in response_content and "message" in response_content["error"]:
|
|
||||||
user_friendly = f"API Error: {response_content['error']['message']}"
|
|
||||||
if "type" in response_content["error"]:
|
|
||||||
user_friendly += f" (Type: {response_content['error']['type']})"
|
|
||||||
else: # Handle cases where error is just a JSON dict with unknown format
|
|
||||||
user_friendly = f"API Error: {json.dumps(response_content)}"
|
|
||||||
else:
|
|
||||||
if len(response_content) < 200: # Arbitrary limit for display
|
|
||||||
user_friendly = f"API Error (raw): {response_content}"
|
|
||||||
else:
|
|
||||||
user_friendly = f"API Error (raw, status {response_content})"
|
|
||||||
|
|
||||||
request_logger.log_request_response(
|
|
||||||
operation_id=operation_id,
|
|
||||||
request_method=req_meta[0],
|
|
||||||
request_url=req_meta[1],
|
|
||||||
response_status_code=exc.status,
|
|
||||||
response_headers=dict(req_meta[5]) if req_meta[5] else None,
|
|
||||||
response_content=response_content,
|
|
||||||
error_message=f"HTTP Error {exc.status}",
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
|
|
||||||
if response_content:
|
|
||||||
logging.debug(f"[DEBUG] Response content: {response_content}")
|
|
||||||
|
|
||||||
# Retry if eligible
|
|
||||||
if status_code in self.retry_status_codes and retry_count < self.max_retries:
|
|
||||||
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
|
|
||||||
logging.warning(
|
|
||||||
"HTTP error %s. Retrying in %.2fs (%s/%s)",
|
|
||||||
status_code,
|
|
||||||
delay,
|
|
||||||
retry_count + 1,
|
|
||||||
self.max_retries,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
return await self.request(
|
|
||||||
req_meta[0], # method
|
|
||||||
req_meta[1].replace(self.base_url, ""), # path
|
|
||||||
params=req_meta[2],
|
|
||||||
data=req_meta[3],
|
|
||||||
files=req_meta[4],
|
|
||||||
headers=req_meta[5],
|
|
||||||
content_type=req_meta[6],
|
|
||||||
multipart_parser=req_meta[7],
|
|
||||||
retry_count=retry_count + 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise Exception(user_friendly) from exc
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _unpack_tuple(t):
|
|
||||||
"""Helper to normalise (filename, file, content_type) tuples."""
|
|
||||||
if len(t) == 3:
|
|
||||||
return t
|
|
||||||
elif len(t) == 2:
|
|
||||||
return t[0], t[1], "application/octet-stream"
|
|
||||||
else:
|
|
||||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
|
||||||
|
|
||||||
async def _get_session(self) -> aiohttp.ClientSession:
|
|
||||||
if self._session is None or self._session.closed:
|
|
||||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
|
||||||
self._session = aiohttp.ClientSession(timeout=timeout)
|
|
||||||
self._owns_session = True
|
|
||||||
return self._session
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
if self._owns_session and self._session and not self._session.closed:
|
|
||||||
await self._session.close()
|
|
||||||
|
|
||||||
async def __aenter__(self) -> "ApiClient":
|
|
||||||
"""Allow usage as async‑context‑manager – ensures clean teardown"""
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
|
||||||
await self.close()
|
|
||||||
|
|
||||||
|
|
||||||
class ApiEndpoint(Generic[T, R]):
|
|
||||||
"""Defines an API endpoint with its request and response types"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
path: str,
|
|
||||||
method: HttpMethod,
|
|
||||||
request_model: Type[T],
|
|
||||||
response_model: Type[R],
|
|
||||||
query_params: Optional[Dict[str, Any]] = None,
|
|
||||||
):
|
|
||||||
"""Initialize an API endpoint definition.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: The URL path for this endpoint, can include placeholders like {id}
|
|
||||||
method: The HTTP method to use (GET, POST, etc.)
|
|
||||||
request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
|
|
||||||
response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
|
|
||||||
query_params: Optional dictionary of query parameters to include in the request
|
|
||||||
"""
|
|
||||||
self.path = path
|
|
||||||
self.method = method
|
|
||||||
self.request_model = request_model
|
|
||||||
self.response_model = response_model
|
|
||||||
self.query_params = query_params or {}
|
|
||||||
|
|
||||||
|
|
||||||
class SynchronousOperation(Generic[T, R]):
|
|
||||||
"""Represents a single synchronous API operation."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
endpoint: ApiEndpoint[T, R],
|
|
||||||
request: T,
|
|
||||||
files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
|
|
||||||
api_base: str | None = None,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
comfy_api_key: Optional[str] = None,
|
|
||||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
|
||||||
timeout: float = 7200.0,
|
|
||||||
verify_ssl: bool = True,
|
|
||||||
content_type: str = "application/json",
|
|
||||||
multipart_parser: Callable | None = None,
|
|
||||||
max_retries: int = 3,
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
) -> None:
|
|
||||||
self.endpoint = endpoint
|
|
||||||
self.request = request
|
|
||||||
self.files = files
|
|
||||||
self.api_base: str = api_base or args.comfy_api_base
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.comfy_api_key = comfy_api_key
|
|
||||||
if auth_kwargs is not None:
|
|
||||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
|
||||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
|
||||||
self.timeout = timeout
|
|
||||||
self.verify_ssl = verify_ssl
|
|
||||||
self.content_type = content_type
|
|
||||||
self.multipart_parser = multipart_parser
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.retry_backoff_factor = retry_backoff_factor
|
|
||||||
|
|
||||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
|
||||||
owns_client = client is None
|
|
||||||
if owns_client:
|
|
||||||
client = ApiClient(
|
|
||||||
base_url=self.api_base,
|
|
||||||
auth_token=self.auth_token,
|
|
||||||
comfy_api_key=self.comfy_api_key,
|
|
||||||
timeout=self.timeout,
|
|
||||||
verify_ssl=self.verify_ssl,
|
|
||||||
max_retries=self.max_retries,
|
|
||||||
retry_delay=self.retry_delay,
|
|
||||||
retry_backoff_factor=self.retry_backoff_factor,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
request_dict: Optional[Dict[str, Any]]
|
|
||||||
if isinstance(self.request, EmptyRequest):
|
|
||||||
request_dict = None
|
|
||||||
else:
|
|
||||||
request_dict = self.request.model_dump(exclude_none=True)
|
|
||||||
for k, v in list(request_dict.items()):
|
|
||||||
if isinstance(v, Enum):
|
|
||||||
request_dict[k] = v.value
|
|
||||||
|
|
||||||
logging.debug(
|
|
||||||
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
|
|
||||||
)
|
|
||||||
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
|
|
||||||
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
|
|
||||||
|
|
||||||
response_json = await client.request(
|
|
||||||
self.endpoint.method.value,
|
|
||||||
self.endpoint.path,
|
|
||||||
params=self.endpoint.query_params,
|
|
||||||
data=request_dict,
|
|
||||||
files=self.files,
|
|
||||||
content_type=self.content_type,
|
|
||||||
multipart_parser=self.multipart_parser,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.debug("=" * 50)
|
|
||||||
logging.debug("[DEBUG] RESPONSE DETAILS:")
|
|
||||||
logging.debug("[DEBUG] Status Code: 200 (Success)")
|
|
||||||
logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
|
|
||||||
logging.debug("=" * 50)
|
|
||||||
|
|
||||||
parsed_response = self.endpoint.response_model.model_validate(response_json)
|
|
||||||
logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
|
|
||||||
return parsed_response
|
|
||||||
finally:
|
|
||||||
if owns_client:
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
|
||||||
"""Enum for task status values"""
|
|
||||||
|
|
||||||
COMPLETED = "completed"
|
|
||||||
FAILED = "failed"
|
|
||||||
PENDING = "pending"
|
|
||||||
|
|
||||||
|
|
||||||
class PollingOperation(Generic[T, R]):
|
|
||||||
"""Represents an asynchronous API operation that requires polling for completion."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
poll_endpoint: ApiEndpoint[EmptyRequest, R],
|
|
||||||
completed_statuses: list[str],
|
|
||||||
failed_statuses: list[str],
|
|
||||||
status_extractor: Callable[[R], str],
|
|
||||||
progress_extractor: Callable[[R], float] | None = None,
|
|
||||||
result_url_extractor: Callable[[R], str] | None = None,
|
|
||||||
request: Optional[T] = None,
|
|
||||||
api_base: str | None = None,
|
|
||||||
auth_token: Optional[str] = None,
|
|
||||||
comfy_api_key: Optional[str] = None,
|
|
||||||
auth_kwargs: Optional[Dict[str, str]] = None,
|
|
||||||
poll_interval: float = 5.0,
|
|
||||||
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
|
|
||||||
max_retries: int = 3, # Max retries per individual API call
|
|
||||||
retry_delay: float = 1.0,
|
|
||||||
retry_backoff_factor: float = 2.0,
|
|
||||||
estimated_duration: Optional[float] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> None:
|
|
||||||
self.poll_endpoint = poll_endpoint
|
|
||||||
self.request = request
|
|
||||||
self.api_base: str = api_base or args.comfy_api_base
|
|
||||||
self.auth_token = auth_token
|
|
||||||
self.comfy_api_key = comfy_api_key
|
|
||||||
if auth_kwargs is not None:
|
|
||||||
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
|
|
||||||
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
|
|
||||||
self.poll_interval = poll_interval
|
|
||||||
self.max_poll_attempts = max_poll_attempts
|
|
||||||
self.max_retries = max_retries
|
|
||||||
self.retry_delay = retry_delay
|
|
||||||
self.retry_backoff_factor = retry_backoff_factor
|
|
||||||
self.estimated_duration = estimated_duration
|
|
||||||
self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
|
|
||||||
self.progress_extractor = progress_extractor
|
|
||||||
self.result_url_extractor = result_url_extractor
|
|
||||||
self.node_id = node_id
|
|
||||||
self.completed_statuses = completed_statuses
|
|
||||||
self.failed_statuses = failed_statuses
|
|
||||||
self.final_response: Optional[R] = None
|
|
||||||
|
|
||||||
async def execute(self, client: Optional[ApiClient] = None) -> R:
|
|
||||||
owns_client = client is None
|
|
||||||
if owns_client:
|
|
||||||
client = ApiClient(
|
|
||||||
base_url=self.api_base,
|
|
||||||
auth_token=self.auth_token,
|
|
||||||
comfy_api_key=self.comfy_api_key,
|
|
||||||
max_retries=self.max_retries,
|
|
||||||
retry_delay=self.retry_delay,
|
|
||||||
retry_backoff_factor=self.retry_backoff_factor,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return await self._poll_until_complete(client)
|
|
||||||
finally:
|
|
||||||
if owns_client:
|
|
||||||
await client.close()
|
|
||||||
|
|
||||||
def _display_text_on_node(self, text: str):
|
|
||||||
if not self.node_id:
|
|
||||||
return
|
|
||||||
PromptServer.instance.send_progress_text(text, self.node_id)
|
|
||||||
|
|
||||||
def _display_time_progress_on_node(self, time_completed: int | float):
|
|
||||||
if not self.node_id:
|
|
||||||
return
|
|
||||||
if self.estimated_duration is not None:
|
|
||||||
remaining = max(0, int(self.estimated_duration) - time_completed)
|
|
||||||
message = f"Task in progress: {time_completed}s (~{remaining}s remaining)"
|
|
||||||
else:
|
|
||||||
message = f"Task in progress: {time_completed}s"
|
|
||||||
self._display_text_on_node(message)
|
|
||||||
|
|
||||||
def _check_task_status(self, response: R) -> TaskStatus:
|
|
||||||
try:
|
|
||||||
status = self.status_extractor(response)
|
|
||||||
if status in self.completed_statuses:
|
|
||||||
return TaskStatus.COMPLETED
|
|
||||||
if status in self.failed_statuses:
|
|
||||||
return TaskStatus.FAILED
|
|
||||||
return TaskStatus.PENDING
|
|
||||||
except Exception as e:
|
|
||||||
logging.error("Error extracting status: %s", e)
|
|
||||||
return TaskStatus.PENDING
|
|
||||||
|
|
||||||
async def _poll_until_complete(self, client: ApiClient) -> R:
|
|
||||||
"""Poll until the task is complete"""
|
|
||||||
consecutive_errors = 0
|
|
||||||
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
|
|
||||||
|
|
||||||
if self.progress_extractor:
|
|
||||||
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
|
|
||||||
|
|
||||||
status = TaskStatus.PENDING
|
|
||||||
for poll_count in range(1, self.max_poll_attempts + 1):
|
|
||||||
try:
|
|
||||||
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
|
|
||||||
|
|
||||||
request_dict = (
|
|
||||||
None if self.request is None else self.request.model_dump(exclude_none=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
if poll_count == 1:
|
|
||||||
logging.debug(
|
|
||||||
f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
|
|
||||||
)
|
|
||||||
logging.debug(
|
|
||||||
f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Query task status
|
|
||||||
resp = await client.request(
|
|
||||||
self.poll_endpoint.method.value,
|
|
||||||
self.poll_endpoint.path,
|
|
||||||
params=self.poll_endpoint.query_params,
|
|
||||||
data=request_dict,
|
|
||||||
)
|
|
||||||
consecutive_errors = 0 # reset on success
|
|
||||||
response_obj: R = self.poll_endpoint.response_model.model_validate(resp)
|
|
||||||
|
|
||||||
# Check if task is complete
|
|
||||||
status = self._check_task_status(response_obj)
|
|
||||||
logging.debug(f"[DEBUG] Task Status: {status}")
|
|
||||||
|
|
||||||
# If progress extractor is provided, extract progress
|
|
||||||
if self.progress_extractor:
|
|
||||||
new_progress = self.progress_extractor(response_obj)
|
|
||||||
if new_progress is not None:
|
|
||||||
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
|
|
||||||
|
|
||||||
if status == TaskStatus.COMPLETED:
|
|
||||||
message = "Task completed successfully"
|
|
||||||
if self.result_url_extractor:
|
|
||||||
result_url = self.result_url_extractor(response_obj)
|
|
||||||
if result_url:
|
|
||||||
message = f"Result URL: {result_url}"
|
|
||||||
logging.debug(f"[DEBUG] {message}")
|
|
||||||
self._display_text_on_node(message)
|
|
||||||
self.final_response = response_obj
|
|
||||||
if self.progress_extractor:
|
|
||||||
progress.update(100)
|
|
||||||
return self.final_response
|
|
||||||
if status == TaskStatus.FAILED:
|
|
||||||
message = f"Task failed: {json.dumps(resp)}"
|
|
||||||
logging.error(f"[DEBUG] {message}")
|
|
||||||
raise Exception(message)
|
|
||||||
logging.debug("[DEBUG] Task still pending, continuing to poll...")
|
|
||||||
# Task pending – wait
|
|
||||||
for i in range(int(self.poll_interval)):
|
|
||||||
self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i)
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
|
|
||||||
except (LocalNetworkError, ApiServerError, NetworkError) as e:
|
|
||||||
consecutive_errors += 1
|
|
||||||
if consecutive_errors >= max_consecutive_errors:
|
|
||||||
raise Exception(
|
|
||||||
f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
|
|
||||||
) from e
|
|
||||||
logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
|
|
||||||
await asyncio.sleep(self.poll_interval)
|
|
||||||
except Exception as e:
|
|
||||||
# For other errors, increment count and potentially abort
|
|
||||||
consecutive_errors += 1
|
|
||||||
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
|
|
||||||
raise Exception(
|
|
||||||
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
logging.error(f"[DEBUG] Polling error: {str(e)}")
|
|
||||||
logging.warning(
|
|
||||||
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
|
|
||||||
f"Will retry in {self.poll_interval} seconds."
|
|
||||||
)
|
|
||||||
await asyncio.sleep(self.poll_interval)
|
|
||||||
|
|
||||||
# If we've exhausted all polling attempts
|
|
||||||
raise Exception(
|
|
||||||
f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
|
|
||||||
"The operation may still be running on the server but is taking longer than expected."
|
|
||||||
)
|
|
||||||
@ -1,19 +1,22 @@
|
|||||||
from __future__ import annotations
|
from typing import Optional
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiImageConfig(BaseModel):
|
||||||
|
aspectRatio: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||||
responseModalities: Optional[List[str]] = None
|
responseModalities: Optional[list[str]] = None
|
||||||
|
imageConfig: Optional[GeminiImageConfig] = None
|
||||||
|
|
||||||
|
|
||||||
class GeminiImageGenerateContentRequest(BaseModel):
|
class GeminiImageGenerateContentRequest(BaseModel):
|
||||||
contents: List[GeminiContent]
|
contents: list[GeminiContent]
|
||||||
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
generationConfig: Optional[GeminiImageGenerationConfig] = None
|
||||||
safetySettings: Optional[List[GeminiSafetySetting]] = None
|
safetySettings: Optional[list[GeminiSafetySetting]] = None
|
||||||
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
systemInstruction: Optional[GeminiSystemInstructionContent] = None
|
||||||
tools: Optional[List[GeminiTool]] = None
|
tools: Optional[list[GeminiTool]] = None
|
||||||
videoMetadata: Optional[GeminiVideoMetadata] = None
|
videoMetadata: Optional[GeminiVideoMetadata] = None
|
||||||
|
|||||||
120
comfy_api_nodes/apis/minimax_api.py
Normal file
120
comfy_api_nodes/apis/minimax_api.py
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxBaseResponse(BaseModel):
|
||||||
|
status_code: int = Field(
|
||||||
|
...,
|
||||||
|
description='Status code. 0 indicates success, other values indicate errors.',
|
||||||
|
)
|
||||||
|
status_msg: str = Field(
|
||||||
|
..., description='Specific error details or success message.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
|
||||||
|
bytes: Optional[int] = Field(None, description='File size in bytes')
|
||||||
|
created_at: Optional[int] = Field(
|
||||||
|
None, description='Unix timestamp when the file was created, in seconds'
|
||||||
|
)
|
||||||
|
download_url: Optional[str] = Field(
|
||||||
|
None, description='The URL to download the video'
|
||||||
|
)
|
||||||
|
backup_download_url: Optional[str] = Field(
|
||||||
|
None, description='The backup URL to download the video'
|
||||||
|
)
|
||||||
|
|
||||||
|
file_id: Optional[int] = Field(None, description='Unique identifier for the file')
|
||||||
|
filename: Optional[str] = Field(None, description='The name of the file')
|
||||||
|
purpose: Optional[str] = Field(None, description='The purpose of using the file')
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxFileRetrieveResponse(BaseModel):
|
||||||
|
base_resp: MinimaxBaseResponse
|
||||||
|
file: File
|
||||||
|
|
||||||
|
|
||||||
|
class MiniMaxModel(str, Enum):
|
||||||
|
T2V_01_Director = 'T2V-01-Director'
|
||||||
|
I2V_01_Director = 'I2V-01-Director'
|
||||||
|
S2V_01 = 'S2V-01'
|
||||||
|
I2V_01 = 'I2V-01'
|
||||||
|
I2V_01_live = 'I2V-01-live'
|
||||||
|
T2V_01 = 'T2V-01'
|
||||||
|
Hailuo_02 = 'MiniMax-Hailuo-02'
|
||||||
|
|
||||||
|
|
||||||
|
class Status6(str, Enum):
|
||||||
|
Queueing = 'Queueing'
|
||||||
|
Preparing = 'Preparing'
|
||||||
|
Processing = 'Processing'
|
||||||
|
Success = 'Success'
|
||||||
|
Fail = 'Fail'
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxTaskResultResponse(BaseModel):
|
||||||
|
base_resp: MinimaxBaseResponse
|
||||||
|
file_id: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
|
||||||
|
)
|
||||||
|
status: Status6 = Field(
|
||||||
|
...,
|
||||||
|
description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
|
||||||
|
)
|
||||||
|
task_id: str = Field(..., description='The task ID being queried.')
|
||||||
|
|
||||||
|
|
||||||
|
class SubjectReferenceItem(BaseModel):
|
||||||
|
image: Optional[str] = Field(
|
||||||
|
None, description='URL or base64 encoding of the subject reference image.'
|
||||||
|
)
|
||||||
|
mask: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='URL or base64 encoding of the mask for the subject reference image.',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxVideoGenerationRequest(BaseModel):
|
||||||
|
callback_url: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Optional. URL to receive real-time status updates about the video generation task.',
|
||||||
|
)
|
||||||
|
first_frame_image: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
|
||||||
|
)
|
||||||
|
model: MiniMaxModel = Field(
|
||||||
|
...,
|
||||||
|
description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
|
||||||
|
)
|
||||||
|
prompt: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
|
||||||
|
max_length=2000,
|
||||||
|
)
|
||||||
|
prompt_optimizer: Optional[bool] = Field(
|
||||||
|
True,
|
||||||
|
description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
|
||||||
|
)
|
||||||
|
subject_reference: Optional[list[SubjectReferenceItem]] = Field(
|
||||||
|
None,
|
||||||
|
description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
|
||||||
|
)
|
||||||
|
duration: Optional[int] = Field(
|
||||||
|
None,
|
||||||
|
description="The length of the output video in seconds."
|
||||||
|
)
|
||||||
|
resolution: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxVideoGenerationResponse(BaseModel):
|
||||||
|
base_resp: MinimaxBaseResponse
|
||||||
|
task_id: str = Field(
|
||||||
|
..., description='The task ID for the asynchronous video generation task.'
|
||||||
|
)
|
||||||
100
comfy_api_nodes/apis/pika_api.py
Normal file
100
comfy_api_nodes/apis/pika_api.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import Optional
|
||||||
|
from enum import Enum
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Pikaffect(str, Enum):
|
||||||
|
Cake_ify = "Cake-ify"
|
||||||
|
Crumble = "Crumble"
|
||||||
|
Crush = "Crush"
|
||||||
|
Decapitate = "Decapitate"
|
||||||
|
Deflate = "Deflate"
|
||||||
|
Dissolve = "Dissolve"
|
||||||
|
Explode = "Explode"
|
||||||
|
Eye_pop = "Eye-pop"
|
||||||
|
Inflate = "Inflate"
|
||||||
|
Levitate = "Levitate"
|
||||||
|
Melt = "Melt"
|
||||||
|
Peel = "Peel"
|
||||||
|
Poke = "Poke"
|
||||||
|
Squish = "Squish"
|
||||||
|
Ta_da = "Ta-da"
|
||||||
|
Tear = "Tear"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
|
||||||
|
aspectRatio: Optional[float] = Field(None, description='Aspect ratio (width / height)')
|
||||||
|
duration: Optional[int] = Field(5)
|
||||||
|
ingredientsMode: str = Field(...)
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
resolution: Optional[str] = Field('1080p')
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaGenerateResponse(BaseModel):
|
||||||
|
video_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
|
||||||
|
duration: Optional[int] = 5
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
|
||||||
|
duration: Optional[int] = Field(None, ge=5, le=10)
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: str = Field(...)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
|
||||||
|
aspectRatio: Optional[float] = Field(
|
||||||
|
1.7777777777777777,
|
||||||
|
description='Aspect ratio (width / height)',
|
||||||
|
ge=0.4,
|
||||||
|
le=2.5,
|
||||||
|
)
|
||||||
|
duration: Optional[int] = 5
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: str = Field(...)
|
||||||
|
resolution: Optional[str] = '1080p'
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
pikaffect: Optional[str] = None
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
|
||||||
|
negativePrompt: Optional[str] = Field(None)
|
||||||
|
promptText: Optional[str] = Field(None)
|
||||||
|
seed: Optional[int] = Field(None)
|
||||||
|
modifyRegionRoi: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class PikaStatusEnum(str, Enum):
|
||||||
|
queued = "queued"
|
||||||
|
started = "started"
|
||||||
|
finished = "finished"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class PikaVideoResponse(BaseModel):
|
||||||
|
id: str = Field(...)
|
||||||
|
progress: Optional[int] = Field(None)
|
||||||
|
status: PikaStatusEnum
|
||||||
|
url: Optional[str] = Field(None)
|
||||||
@ -1,13 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from comfy_api_nodes.apis import (
|
|
||||||
TripoModelVersion,
|
|
||||||
TripoTextureQuality,
|
|
||||||
)
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, List, Dict, Any, Union
|
from typing import Optional, List, Dict, Any, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, RootModel
|
from pydantic import BaseModel, Field, RootModel
|
||||||
|
|
||||||
|
class TripoModelVersion(str, Enum):
|
||||||
|
v2_5_20250123 = 'v2.5-20250123'
|
||||||
|
v2_0_20240919 = 'v2.0-20240919'
|
||||||
|
v1_4_20240625 = 'v1.4-20240625'
|
||||||
|
|
||||||
|
|
||||||
|
class TripoTextureQuality(str, Enum):
|
||||||
|
standard = 'standard'
|
||||||
|
detailed = 'detailed'
|
||||||
|
|
||||||
|
|
||||||
class TripoStyle(str, Enum):
|
class TripoStyle(str, Enum):
|
||||||
PERSON_TO_CARTOON = "person:person2cartoon"
|
PERSON_TO_CARTOON = "person:person2cartoon"
|
||||||
ANIMAL_VENOM = "animal:venom"
|
ANIMAL_VENOM = "animal:venom"
|
||||||
|
|||||||
111
comfy_api_nodes/apis/veo_api.py
Normal file
111
comfy_api_nodes/apis/veo_api.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
from typing import Optional, Union
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Image2(BaseModel):
|
||||||
|
bytesBase64Encoded: str
|
||||||
|
gcsUri: Optional[str] = None
|
||||||
|
mimeType: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Image3(BaseModel):
|
||||||
|
bytesBase64Encoded: Optional[str] = None
|
||||||
|
gcsUri: str
|
||||||
|
mimeType: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Instance1(BaseModel):
|
||||||
|
image: Optional[Union[Image2, Image3]] = Field(
|
||||||
|
None, description='Optional image to guide video generation'
|
||||||
|
)
|
||||||
|
prompt: str = Field(..., description='Text description of the video')
|
||||||
|
|
||||||
|
|
||||||
|
class PersonGeneration1(str, Enum):
|
||||||
|
ALLOW = 'ALLOW'
|
||||||
|
BLOCK = 'BLOCK'
|
||||||
|
|
||||||
|
|
||||||
|
class Parameters1(BaseModel):
|
||||||
|
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
|
||||||
|
durationSeconds: Optional[int] = None
|
||||||
|
enhancePrompt: Optional[bool] = None
|
||||||
|
generateAudio: Optional[bool] = Field(
|
||||||
|
None,
|
||||||
|
description='Generate audio for the video. Only supported by veo 3 models.',
|
||||||
|
)
|
||||||
|
negativePrompt: Optional[str] = None
|
||||||
|
personGeneration: Optional[PersonGeneration1] = None
|
||||||
|
sampleCount: Optional[int] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
storageUri: Optional[str] = Field(
|
||||||
|
None, description='Optional Cloud Storage URI to upload the video'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VeoGenVidRequest(BaseModel):
|
||||||
|
instances: Optional[list[Instance1]] = None
|
||||||
|
parameters: Optional[Parameters1] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VeoGenVidResponse(BaseModel):
|
||||||
|
name: str = Field(
|
||||||
|
...,
|
||||||
|
description='Operation resource name',
|
||||||
|
examples=[
|
||||||
|
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class VeoGenVidPollRequest(BaseModel):
|
||||||
|
operationName: str = Field(
|
||||||
|
...,
|
||||||
|
description='Full operation name (from predict response)',
|
||||||
|
examples=[
|
||||||
|
'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Video(BaseModel):
|
||||||
|
bytesBase64Encoded: Optional[str] = Field(
|
||||||
|
None, description='Base64-encoded video content'
|
||||||
|
)
|
||||||
|
gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
|
||||||
|
mimeType: Optional[str] = Field(None, description='Video MIME type')
|
||||||
|
|
||||||
|
|
||||||
|
class Error1(BaseModel):
|
||||||
|
code: Optional[int] = Field(None, description='Error code')
|
||||||
|
message: Optional[str] = Field(None, description='Error message')
|
||||||
|
|
||||||
|
|
||||||
|
class Response1(BaseModel):
|
||||||
|
field_type: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
alias='@type',
|
||||||
|
examples=[
|
||||||
|
'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
|
||||||
|
],
|
||||||
|
)
|
||||||
|
raiMediaFilteredCount: Optional[int] = Field(
|
||||||
|
None, description='Count of media filtered by responsible AI policies'
|
||||||
|
)
|
||||||
|
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||||
|
None, description='Reasons why media was filtered by responsible AI policies'
|
||||||
|
)
|
||||||
|
videos: Optional[list[Video]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class VeoGenVidPollResponse(BaseModel):
|
||||||
|
done: Optional[bool] = None
|
||||||
|
error: Optional[Error1] = Field(
|
||||||
|
None, description='Error details if operation failed'
|
||||||
|
)
|
||||||
|
name: Optional[str] = None
|
||||||
|
response: Optional[Response1] = Field(
|
||||||
|
None, description='The actual prediction response if done is true'
|
||||||
|
)
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -2,45 +2,47 @@
|
|||||||
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
API Nodes for Gemini Multimodal LLM Usage via Remote API
|
||||||
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
import base64
|
import base64
|
||||||
from io import BytesIO
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Literal
|
from io import BytesIO
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
from server import PromptServer
|
from comfy_api.util import VideoCodec, VideoContainer
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis import (
|
||||||
GeminiContent,
|
GeminiContent,
|
||||||
GeminiGenerateContentRequest,
|
GeminiGenerateContentRequest,
|
||||||
GeminiGenerateContentResponse,
|
GeminiGenerateContentResponse,
|
||||||
GeminiInlineData,
|
GeminiInlineData,
|
||||||
GeminiPart,
|
|
||||||
GeminiMimeType,
|
GeminiMimeType,
|
||||||
|
GeminiPart,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest
|
from comfy_api_nodes.apis.gemini_api import (
|
||||||
from comfy_api_nodes.apis.client import (
|
GeminiImageConfig,
|
||||||
|
GeminiImageGenerateContentRequest,
|
||||||
|
GeminiImageGenerationConfig,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
validate_string,
|
|
||||||
audio_to_base64_string,
|
audio_to_base64_string,
|
||||||
video_to_base64_string,
|
|
||||||
tensor_to_base64_string,
|
|
||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
|
sync_op,
|
||||||
|
tensor_to_base64_string,
|
||||||
|
validate_string,
|
||||||
|
video_to_base64_string,
|
||||||
)
|
)
|
||||||
from comfy_api.util import VideoContainer, VideoCodec
|
from server import PromptServer
|
||||||
|
|
||||||
|
|
||||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||||
@ -63,50 +65,7 @@ class GeminiImageModel(str, Enum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
|
gemini_2_5_flash_image_preview = "gemini-2.5-flash-image-preview"
|
||||||
|
gemini_2_5_flash_image = "gemini-2.5-flash-image"
|
||||||
|
|
||||||
def get_gemini_endpoint(
|
|
||||||
model: GeminiModel,
|
|
||||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
|
||||||
"""
|
|
||||||
Get the API endpoint for a given Gemini model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The Gemini model to use, either as enum or string value.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiEndpoint configured for the specific Gemini model.
|
|
||||||
"""
|
|
||||||
if isinstance(model, str):
|
|
||||||
model = GeminiModel(model)
|
|
||||||
return ApiEndpoint(
|
|
||||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=GeminiGenerateContentRequest,
|
|
||||||
response_model=GeminiGenerateContentResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_gemini_image_endpoint(
|
|
||||||
model: GeminiImageModel,
|
|
||||||
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
|
|
||||||
"""
|
|
||||||
Get the API endpoint for a given Gemini model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The Gemini model to use, either as enum or string value.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ApiEndpoint configured for the specific Gemini model.
|
|
||||||
"""
|
|
||||||
if isinstance(model, str):
|
|
||||||
model = GeminiImageModel(model)
|
|
||||||
return ApiEndpoint(
|
|
||||||
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=GeminiImageGenerateContentRequest,
|
|
||||||
response_model=GeminiGenerateContentResponse,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||||
@ -121,9 +80,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
|||||||
"""
|
"""
|
||||||
image_parts: list[GeminiPart] = []
|
image_parts: list[GeminiPart] = []
|
||||||
for image_index in range(image_input.shape[0]):
|
for image_index in range(image_input.shape[0]):
|
||||||
image_as_b64 = tensor_to_base64_string(
|
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
|
||||||
image_input[image_index].unsqueeze(0)
|
|
||||||
)
|
|
||||||
image_parts.append(
|
image_parts.append(
|
||||||
GeminiPart(
|
GeminiPart(
|
||||||
inlineData=GeminiInlineData(
|
inlineData=GeminiInlineData(
|
||||||
@ -135,37 +92,7 @@ def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
|||||||
return image_parts
|
return image_parts
|
||||||
|
|
||||||
|
|
||||||
def create_text_part(text: str) -> GeminiPart:
|
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
|
||||||
"""
|
|
||||||
Create a text part for the Gemini API request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text content to include in the request.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A GeminiPart object with the text content.
|
|
||||||
"""
|
|
||||||
return GeminiPart(text=text)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parts_from_response(
|
|
||||||
response: GeminiGenerateContentResponse
|
|
||||||
) -> list[GeminiPart]:
|
|
||||||
"""
|
|
||||||
Extract all parts from the Gemini API response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: The API response from Gemini.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of response parts from the first candidate.
|
|
||||||
"""
|
|
||||||
return response.candidates[0].content.parts
|
|
||||||
|
|
||||||
|
|
||||||
def get_parts_by_type(
|
|
||||||
response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
|
|
||||||
) -> list[GeminiPart]:
|
|
||||||
"""
|
"""
|
||||||
Filter response parts by their type.
|
Filter response parts by their type.
|
||||||
|
|
||||||
@ -177,14 +104,10 @@ def get_parts_by_type(
|
|||||||
List of response parts matching the requested type.
|
List of response parts matching the requested type.
|
||||||
"""
|
"""
|
||||||
parts = []
|
parts = []
|
||||||
for part in get_parts_from_response(response):
|
for part in response.candidates[0].content.parts:
|
||||||
if part_type == "text" and hasattr(part, "text") and part.text:
|
if part_type == "text" and hasattr(part, "text") and part.text:
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
elif (
|
elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type:
|
||||||
hasattr(part, "inlineData")
|
|
||||||
and part.inlineData
|
|
||||||
and part.inlineData.mimeType == part_type
|
|
||||||
):
|
|
||||||
parts.append(part)
|
parts.append(part)
|
||||||
# Skip parts that don't match the requested type
|
# Skip parts that don't match the requested type
|
||||||
return parts
|
return parts
|
||||||
@ -212,11 +135,11 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
|
|||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
image_tensors.append(returned_image)
|
image_tensors.append(returned_image)
|
||||||
if len(image_tensors) == 0:
|
if len(image_tensors) == 0:
|
||||||
return torch.zeros((1,1024,1024,4))
|
return torch.zeros((1, 1024, 1024, 4))
|
||||||
return torch.cat(image_tensors, dim=0)
|
return torch.cat(image_tensors, dim=0)
|
||||||
|
|
||||||
|
|
||||||
class GeminiNode(ComfyNodeABC):
|
class GeminiNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Node to generate text responses from a Gemini model.
|
Node to generate text responses from a Gemini model.
|
||||||
|
|
||||||
@ -227,96 +150,79 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def define_schema(cls):
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="GeminiNode",
|
||||||
"prompt": (
|
display_name="Google Gemini",
|
||||||
IO.STRING,
|
category="api node/text/Gemini",
|
||||||
{
|
description="Generate text responses with Google's Gemini AI model. "
|
||||||
"multiline": True,
|
"You can provide multiple types of inputs (text, images, audio, video) "
|
||||||
"default": "",
|
"as context for generating more relevant and meaningful responses.",
|
||||||
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
|
inputs=[
|
||||||
},
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text inputs to the model, used to generate a response. "
|
||||||
|
"You can include detailed instructions, questions, or context for the model.",
|
||||||
),
|
),
|
||||||
"model": (
|
IO.Combo.Input(
|
||||||
IO.COMBO,
|
"model",
|
||||||
{
|
options=GeminiModel,
|
||||||
"tooltip": "The Gemini model to use for generating responses.",
|
default=GeminiModel.gemini_2_5_pro,
|
||||||
"options": [model.value for model in GeminiModel],
|
tooltip="The Gemini model to use for generating responses.",
|
||||||
"default": GeminiModel.gemini_2_5_pro.value,
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"seed": (
|
IO.Int.Input(
|
||||||
IO.INT,
|
"seed",
|
||||||
{
|
default=42,
|
||||||
"default": 42,
|
min=0,
|
||||||
"min": 0,
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
"max": 0xFFFFFFFFFFFFFFFF,
|
control_after_generate=True,
|
||||||
"control_after_generate": True,
|
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
|
||||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||||
},
|
"Also, changing the model or parameter settings, such as the temperature, "
|
||||||
|
"can cause variations in the response even when you use the same seed value. "
|
||||||
|
"By default, a random seed value is used.",
|
||||||
),
|
),
|
||||||
},
|
IO.Image.Input(
|
||||||
"optional": {
|
"images",
|
||||||
"images": (
|
optional=True,
|
||||||
IO.IMAGE,
|
tooltip="Optional image(s) to use as context for the model. "
|
||||||
{
|
"To include multiple images, you can use the Batch Images node.",
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"audio": (
|
IO.Audio.Input(
|
||||||
IO.AUDIO,
|
"audio",
|
||||||
{
|
optional=True,
|
||||||
"tooltip": "Optional audio to use as context for the model.",
|
tooltip="Optional audio to use as context for the model.",
|
||||||
"default": None,
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"video": (
|
IO.Video.Input(
|
||||||
IO.VIDEO,
|
"video",
|
||||||
{
|
optional=True,
|
||||||
"tooltip": "Optional video to use as context for the model.",
|
tooltip="Optional video to use as context for the model.",
|
||||||
"default": None,
|
|
||||||
},
|
|
||||||
),
|
),
|
||||||
"files": (
|
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||||
"GEMINI_INPUT_FILES",
|
"files",
|
||||||
{
|
optional=True,
|
||||||
"default": None,
|
tooltip="Optional file(s) to use as context for the model. "
|
||||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||||
},
|
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
"hidden": {
|
outputs=[
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
IO.String.Output(),
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
],
|
||||||
"unique_id": "UNIQUE_ID",
|
hidden=[
|
||||||
},
|
IO.Hidden.auth_token_comfy_org,
|
||||||
}
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
|
],
|
||||||
RETURN_TYPES = ("STRING",)
|
is_api_node=True,
|
||||||
FUNCTION = "api_call"
|
|
||||||
CATEGORY = "api node/text/Gemini"
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
|
|
||||||
"""
|
|
||||||
Convert video input to Gemini API compatible parts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video_input: Video tensor from ComfyUI.
|
|
||||||
**kwargs: Additional arguments to pass to the conversion function.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of GeminiPart objects containing the encoded video.
|
|
||||||
"""
|
|
||||||
|
|
||||||
base_64_string = video_to_base64_string(
|
|
||||||
video_input,
|
|
||||||
container_format=VideoContainer.MP4,
|
|
||||||
codec=VideoCodec.H264
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||||
|
"""Convert video input to Gemini API compatible parts."""
|
||||||
|
|
||||||
|
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
return [
|
return [
|
||||||
GeminiPart(
|
GeminiPart(
|
||||||
inlineData=GeminiInlineData(
|
inlineData=GeminiInlineData(
|
||||||
@ -326,7 +232,8 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
|
@classmethod
|
||||||
|
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
||||||
"""
|
"""
|
||||||
Convert audio input to Gemini API compatible parts.
|
Convert audio input to Gemini API compatible parts.
|
||||||
|
|
||||||
@ -339,10 +246,10 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
audio_parts: list[GeminiPart] = []
|
audio_parts: list[GeminiPart] = []
|
||||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||||
audio_at_index = {
|
audio_at_index = Input.Audio(
|
||||||
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
|
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||||
"sample_rate": audio_input["sample_rate"],
|
sample_rate=audio_input["sample_rate"],
|
||||||
}
|
)
|
||||||
# Convert to MP3 format for compatibility with Gemini API
|
# Convert to MP3 format for compatibility with Gemini API
|
||||||
audio_bytes = audio_to_base64_string(
|
audio_bytes = audio_to_base64_string(
|
||||||
audio_at_index,
|
audio_at_index,
|
||||||
@ -359,38 +266,38 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
return audio_parts
|
return audio_parts
|
||||||
|
|
||||||
async def api_call(
|
@classmethod
|
||||||
self,
|
async def execute(
|
||||||
|
cls,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: GeminiModel,
|
model: str,
|
||||||
images: Optional[IO.IMAGE] = None,
|
seed: int,
|
||||||
audio: Optional[IO.AUDIO] = None,
|
images: Optional[torch.Tensor] = None,
|
||||||
video: Optional[IO.VIDEO] = None,
|
audio: Optional[Input.Audio] = None,
|
||||||
|
video: Optional[Input.Video] = None,
|
||||||
files: Optional[list[GeminiPart]] = None,
|
files: Optional[list[GeminiPart]] = None,
|
||||||
unique_id: Optional[str] = None,
|
) -> IO.NodeOutput:
|
||||||
**kwargs,
|
|
||||||
) -> tuple[str]:
|
|
||||||
# Validate inputs
|
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
|
|
||||||
# Create parts list with text prompt as the first part
|
# Create parts list with text prompt as the first part
|
||||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
|
|
||||||
# Add other modal parts
|
# Add other modal parts
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_parts = create_image_parts(images)
|
image_parts = create_image_parts(images)
|
||||||
parts.extend(image_parts)
|
parts.extend(image_parts)
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
parts.extend(self.create_audio_parts(audio))
|
parts.extend(cls.create_audio_parts(audio))
|
||||||
if video is not None:
|
if video is not None:
|
||||||
parts.extend(self.create_video_parts(video))
|
parts.extend(cls.create_video_parts(video))
|
||||||
if files is not None:
|
if files is not None:
|
||||||
parts.extend(files)
|
parts.extend(files)
|
||||||
|
|
||||||
# Create response
|
# Create response
|
||||||
response = await SynchronousOperation(
|
response = await sync_op(
|
||||||
endpoint=get_gemini_endpoint(model),
|
cls,
|
||||||
request=GeminiGenerateContentRequest(
|
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||||
|
data=GeminiGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(
|
GeminiContent(
|
||||||
role="user",
|
role="user",
|
||||||
@ -398,15 +305,15 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
response_model=GeminiGenerateContentResponse,
|
||||||
).execute()
|
)
|
||||||
|
|
||||||
# Get result output
|
# Get result output
|
||||||
output_text = get_text_from_response(response)
|
output_text = get_text_from_response(response)
|
||||||
if unique_id and output_text:
|
if output_text:
|
||||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||||
render_spec = {
|
render_spec = {
|
||||||
"node_id": unique_id,
|
"node_id": cls.hidden.unique_id,
|
||||||
"component": "ChatHistoryWidget",
|
"component": "ChatHistoryWidget",
|
||||||
"props": {
|
"props": {
|
||||||
"history": json.dumps(
|
"history": json.dumps(
|
||||||
@ -426,10 +333,10 @@ class GeminiNode(ComfyNodeABC):
|
|||||||
render_spec,
|
render_spec,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (output_text or "Empty response from Gemini model...",)
|
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||||
|
|
||||||
|
|
||||||
class GeminiInputFiles(ComfyNodeABC):
|
class GeminiInputFiles(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Loads and formats input files for use with the Gemini API.
|
Loads and formats input files for use with the Gemini API.
|
||||||
|
|
||||||
@ -440,7 +347,7 @@ class GeminiInputFiles(ComfyNodeABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def define_schema(cls):
|
||||||
"""
|
"""
|
||||||
For details about the supported file input types, see:
|
For details about the supported file input types, see:
|
||||||
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
|
||||||
@ -455,39 +362,37 @@ class GeminiInputFiles(ComfyNodeABC):
|
|||||||
]
|
]
|
||||||
input_files = sorted(input_files, key=lambda x: x.name)
|
input_files = sorted(input_files, key=lambda x: x.name)
|
||||||
input_files = [f.name for f in input_files]
|
input_files = [f.name for f in input_files]
|
||||||
return {
|
return IO.Schema(
|
||||||
"required": {
|
node_id="GeminiInputFiles",
|
||||||
"file": (
|
display_name="Gemini Input Files",
|
||||||
IO.COMBO,
|
category="api node/text/Gemini",
|
||||||
{
|
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
|
||||||
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
|
"The files will be read by the Gemini model when generating a response. "
|
||||||
"options": input_files,
|
"The contents of the text file count toward the token limit. "
|
||||||
"default": input_files[0] if input_files else None,
|
"🛈 TIP: Can be chained together with other Gemini Input File nodes.",
|
||||||
},
|
inputs=[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"file",
|
||||||
|
options=input_files,
|
||||||
|
default=input_files[0] if input_files else None,
|
||||||
|
tooltip="Input files to include as context for the model. "
|
||||||
|
"Only accepts text (.txt) and PDF (.pdf) files for now.",
|
||||||
),
|
),
|
||||||
},
|
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||||
"optional": {
|
|
||||||
"GEMINI_INPUT_FILES": (
|
|
||||||
"GEMINI_INPUT_FILES",
|
"GEMINI_INPUT_FILES",
|
||||||
{
|
optional=True,
|
||||||
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
|
tooltip="An optional additional file(s) to batch together with the file loaded from this node. "
|
||||||
"default": None,
|
"Allows chaining of input files so that a single message can include multiple input files.",
|
||||||
},
|
|
||||||
),
|
),
|
||||||
},
|
],
|
||||||
}
|
outputs=[
|
||||||
|
IO.Custom("GEMINI_INPUT_FILES").Output(),
|
||||||
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
|
],
|
||||||
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
|
|
||||||
FUNCTION = "prepare_files"
|
|
||||||
CATEGORY = "api node/text/Gemini"
|
|
||||||
|
|
||||||
def create_file_part(self, file_path: str) -> GeminiPart:
|
|
||||||
mime_type = (
|
|
||||||
GeminiMimeType.application_pdf
|
|
||||||
if file_path.endswith(".pdf")
|
|
||||||
else GeminiMimeType.text_plain
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_file_part(cls, file_path: str) -> GeminiPart:
|
||||||
|
mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain
|
||||||
# Use base64 string directly, not the data URI
|
# Use base64 string directly, not the data URI
|
||||||
with open(file_path, "rb") as f:
|
with open(file_path, "rb") as f:
|
||||||
file_content = f.read()
|
file_content = f.read()
|
||||||
@ -500,143 +405,127 @@ class GeminiInputFiles(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_files(
|
|
||||||
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
|
|
||||||
) -> tuple[list[GeminiPart]]:
|
|
||||||
"""
|
|
||||||
Loads and formats input files for Gemini API.
|
|
||||||
"""
|
|
||||||
file_path = folder_paths.get_annotated_filepath(file)
|
|
||||||
input_file_content = self.create_file_part(file_path)
|
|
||||||
files = [input_file_content] + GEMINI_INPUT_FILES
|
|
||||||
return (files,)
|
|
||||||
|
|
||||||
|
|
||||||
class GeminiImage(ComfyNodeABC):
|
|
||||||
"""
|
|
||||||
Node to generate text and image responses from a Gemini model.
|
|
||||||
|
|
||||||
This node allows users to interact with Google's Gemini AI models, providing
|
|
||||||
multimodal inputs (text, images, files) to generate coherent
|
|
||||||
text and image responses. The node works with the latest Gemini models, handling the
|
|
||||||
API communication and response parsing.
|
|
||||||
"""
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypeDict:
|
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
|
||||||
return {
|
"""Loads and formats input files for Gemini API."""
|
||||||
"required": {
|
if GEMINI_INPUT_FILES is None:
|
||||||
"prompt": (
|
GEMINI_INPUT_FILES = []
|
||||||
IO.STRING,
|
file_path = folder_paths.get_annotated_filepath(file)
|
||||||
{
|
input_file_content = cls.create_file_part(file_path)
|
||||||
"multiline": True,
|
return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES)
|
||||||
"default": "",
|
|
||||||
"tooltip": "Text prompt for generation",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"model": (
|
|
||||||
IO.COMBO,
|
|
||||||
{
|
|
||||||
"tooltip": "The Gemini model to use for generating responses.",
|
|
||||||
"options": [model.value for model in GeminiImageModel],
|
|
||||||
"default": GeminiImageModel.gemini_2_5_flash_image_preview.value,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"seed": (
|
|
||||||
IO.INT,
|
|
||||||
{
|
|
||||||
"default": 42,
|
|
||||||
"min": 0,
|
|
||||||
"max": 0xFFFFFFFFFFFFFFFF,
|
|
||||||
"control_after_generate": True,
|
|
||||||
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
"optional": {
|
|
||||||
"images": (
|
|
||||||
IO.IMAGE,
|
|
||||||
{
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
"files": (
|
|
||||||
"GEMINI_INPUT_FILES",
|
|
||||||
{
|
|
||||||
"default": None,
|
|
||||||
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
# TODO: later we can add this parameter later
|
|
||||||
# "n": (
|
|
||||||
# IO.INT,
|
|
||||||
# {
|
|
||||||
# "default": 1,
|
|
||||||
# "min": 1,
|
|
||||||
# "max": 8,
|
|
||||||
# "step": 1,
|
|
||||||
# "display": "number",
|
|
||||||
# "tooltip": "How many images to generate",
|
|
||||||
# },
|
|
||||||
# ),
|
|
||||||
},
|
|
||||||
"hidden": {
|
|
||||||
"auth_token": "AUTH_TOKEN_COMFY_ORG",
|
|
||||||
"comfy_api_key": "API_KEY_COMFY_ORG",
|
|
||||||
"unique_id": "UNIQUE_ID",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
RETURN_TYPES = (IO.IMAGE, IO.STRING)
|
|
||||||
FUNCTION = "api_call"
|
|
||||||
CATEGORY = "api node/image/Gemini"
|
|
||||||
DESCRIPTION = "Edit images synchronously via Google API."
|
|
||||||
API_NODE = True
|
|
||||||
|
|
||||||
async def api_call(
|
class GeminiImage(IO.ComfyNode):
|
||||||
self,
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="GeminiImageNode",
|
||||||
|
display_name="Google Gemini Image",
|
||||||
|
category="api node/image/Gemini",
|
||||||
|
description="Edit images synchronously via Google API.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Text prompt for generation",
|
||||||
|
default="",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=GeminiImageModel,
|
||||||
|
default=GeminiImageModel.gemini_2_5_flash_image,
|
||||||
|
tooltip="The Gemini model to use for generating responses.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=42,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="When seed is fixed to a specific value, the model makes a best effort to provide "
|
||||||
|
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||||
|
"Also, changing the model or parameter settings, such as the temperature, "
|
||||||
|
"can cause variations in the response even when you use the same seed value. "
|
||||||
|
"By default, a random seed value is used.",
|
||||||
|
),
|
||||||
|
IO.Image.Input(
|
||||||
|
"images",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Optional image(s) to use as context for the model. "
|
||||||
|
"To include multiple images, you can use the Batch Images node.",
|
||||||
|
),
|
||||||
|
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||||
|
"files",
|
||||||
|
optional=True,
|
||||||
|
tooltip="Optional file(s) to use as context for the model. "
|
||||||
|
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||||
|
default="auto",
|
||||||
|
tooltip="Defaults to matching the output image size to that of your input image, "
|
||||||
|
"or otherwise generates 1:1 squares.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(),
|
||||||
|
IO.String.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: GeminiImageModel,
|
model: str,
|
||||||
images: Optional[IO.IMAGE] = None,
|
seed: int,
|
||||||
|
images: Optional[torch.Tensor] = None,
|
||||||
files: Optional[list[GeminiPart]] = None,
|
files: Optional[list[GeminiPart]] = None,
|
||||||
n=1,
|
aspect_ratio: str = "auto",
|
||||||
unique_id: Optional[str] = None,
|
) -> IO.NodeOutput:
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# Validate inputs
|
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
# Create parts list with text prompt as the first part
|
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||||
parts: list[GeminiPart] = [create_text_part(prompt)]
|
|
||||||
|
if not aspect_ratio:
|
||||||
|
aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December
|
||||||
|
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||||
|
|
||||||
# Add other modal parts
|
|
||||||
if images is not None:
|
if images is not None:
|
||||||
image_parts = create_image_parts(images)
|
image_parts = create_image_parts(images)
|
||||||
parts.extend(image_parts)
|
parts.extend(image_parts)
|
||||||
if files is not None:
|
if files is not None:
|
||||||
parts.extend(files)
|
parts.extend(files)
|
||||||
|
|
||||||
response = await SynchronousOperation(
|
response = await sync_op(
|
||||||
endpoint=get_gemini_image_endpoint(model),
|
cls,
|
||||||
request=GeminiImageGenerateContentRequest(
|
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||||
|
data=GeminiImageGenerateContentRequest(
|
||||||
contents=[
|
contents=[
|
||||||
GeminiContent(
|
GeminiContent(role="user", parts=parts),
|
||||||
role="user",
|
|
||||||
parts=parts,
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
generationConfig=GeminiImageGenerationConfig(
|
generationConfig=GeminiImageGenerationConfig(
|
||||||
responseModalities=["TEXT","IMAGE"]
|
responseModalities=["TEXT", "IMAGE"],
|
||||||
)
|
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
auth_kwargs=kwargs,
|
response_model=GeminiGenerateContentResponse,
|
||||||
).execute()
|
)
|
||||||
|
|
||||||
output_image = get_image_from_response(response)
|
output_image = get_image_from_response(response)
|
||||||
output_text = get_text_from_response(response)
|
output_text = get_text_from_response(response)
|
||||||
if unique_id and output_text:
|
if output_text:
|
||||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||||
render_spec = {
|
render_spec = {
|
||||||
"node_id": unique_id,
|
"node_id": cls.hidden.unique_id,
|
||||||
"component": "ChatHistoryWidget",
|
"component": "ChatHistoryWidget",
|
||||||
"props": {
|
"props": {
|
||||||
"history": json.dumps(
|
"history": json.dumps(
|
||||||
@ -657,17 +546,18 @@ class GeminiImage(ComfyNodeABC):
|
|||||||
)
|
)
|
||||||
|
|
||||||
output_text = output_text or "Empty response from Gemini model..."
|
output_text = output_text or "Empty response from Gemini model..."
|
||||||
return (output_image, output_text,)
|
return IO.NodeOutput(output_image, output_text)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class GeminiExtension(ComfyExtension):
|
||||||
"GeminiNode": GeminiNode,
|
@override
|
||||||
"GeminiImageNode": GeminiImage,
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
"GeminiInputFiles": GeminiInputFiles,
|
return [
|
||||||
}
|
GeminiNode,
|
||||||
|
GeminiImage,
|
||||||
|
GeminiInputFiles,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
|
||||||
"GeminiNode": "Google Gemini",
|
async def comfy_entrypoint() -> GeminiExtension:
|
||||||
"GeminiImageNode": "Google Gemini Image",
|
return GeminiExtension()
|
||||||
"GeminiInputFiles": "Gemini Input Files",
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -11,19 +11,13 @@ from comfy_api_nodes.apis import (
|
|||||||
IdeogramV3Request,
|
IdeogramV3Request,
|
||||||
IdeogramV3EditRequest,
|
IdeogramV3EditRequest,
|
||||||
)
|
)
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
)
|
|
||||||
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
download_url_to_bytesio,
|
|
||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
|
download_url_as_bytesio,
|
||||||
resize_mask_to_image,
|
resize_mask_to_image,
|
||||||
|
sync_op,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
|
||||||
|
|
||||||
V1_V1_RES_MAP = {
|
V1_V1_RES_MAP = {
|
||||||
"Auto":"AUTO",
|
"Auto":"AUTO",
|
||||||
@ -220,7 +214,7 @@ async def download_and_process_images(image_urls):
|
|||||||
|
|
||||||
for image_url in image_urls:
|
for image_url in image_urls:
|
||||||
# Using functions from apinode_utils.py to handle downloading and processing
|
# Using functions from apinode_utils.py to handle downloading and processing
|
||||||
image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO
|
image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO
|
||||||
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
|
||||||
image_tensors.append(img_tensor)
|
image_tensors.append(img_tensor)
|
||||||
|
|
||||||
@ -233,89 +227,76 @@ async def download_and_process_images(image_urls):
|
|||||||
return stacked_tensors
|
return stacked_tensors
|
||||||
|
|
||||||
|
|
||||||
def display_image_urls_on_node(image_urls, node_id):
|
class IdeogramV1(IO.ComfyNode):
|
||||||
if node_id and image_urls:
|
|
||||||
if len(image_urls) == 1:
|
|
||||||
PromptServer.instance.send_progress_text(
|
|
||||||
f"Generated Image URL:\n{image_urls[0]}", node_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
urls_text = "Generated Image URLs:\n" + "\n".join(
|
|
||||||
f"{i+1}. {url}" for i, url in enumerate(image_urls)
|
|
||||||
)
|
|
||||||
PromptServer.instance.send_progress_text(urls_text, node_id)
|
|
||||||
|
|
||||||
|
|
||||||
class IdeogramV1(comfy_io.ComfyNode):
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="IdeogramV1",
|
node_id="IdeogramV1",
|
||||||
display_name="Ideogram V1",
|
display_name="Ideogram V1",
|
||||||
category="api node/image/Ideogram",
|
category="api node/image/Ideogram",
|
||||||
description="Generates images using the Ideogram V1 model.",
|
description="Generates images using the Ideogram V1 model.",
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the image generation",
|
tooltip="Prompt for the image generation",
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"turbo",
|
"turbo",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=list(V1_V2_RATIO_MAP.keys()),
|
options=list(V1_V2_RATIO_MAP.keys()),
|
||||||
default="1:1",
|
default="1:1",
|
||||||
tooltip="The aspect ratio for image generation.",
|
tooltip="The aspect ratio for image generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"magic_prompt_option",
|
"magic_prompt_option",
|
||||||
options=["AUTO", "ON", "OFF"],
|
options=["AUTO", "ON", "OFF"],
|
||||||
default="AUTO",
|
default="AUTO",
|
||||||
tooltip="Determine if MagicPrompt should be used in generation",
|
tooltip="Determine if MagicPrompt should be used in generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Description of what to exclude from the image",
|
tooltip="Description of what to exclude from the image",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"num_images",
|
"num_images",
|
||||||
default=1,
|
default=1,
|
||||||
min=1,
|
min=1,
|
||||||
max=8,
|
max=8,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -334,77 +315,63 @@ class IdeogramV1(comfy_io.ComfyNode):
|
|||||||
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
|
||||||
model = "V_1_TURBO" if turbo else "V_1"
|
model = "V_1_TURBO" if turbo else "V_1"
|
||||||
|
|
||||||
auth = {
|
response = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||||
}
|
response_model=IdeogramGenerateResponse,
|
||||||
operation = SynchronousOperation(
|
data=IdeogramGenerateRequest(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/ideogram/generate",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=IdeogramGenerateRequest,
|
|
||||||
response_model=IdeogramGenerateResponse,
|
|
||||||
),
|
|
||||||
request=IdeogramGenerateRequest(
|
|
||||||
image_request=ImageRequest(
|
image_request=ImageRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
num_images=num_images,
|
num_images=num_images,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
|
||||||
magic_prompt_option=(
|
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
|
||||||
),
|
|
||||||
negative_prompt=negative_prompt if negative_prompt else None,
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
max_retries=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await operation.execute()
|
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
raise Exception("No images were generated in the response")
|
raise Exception("No images were generated in the response")
|
||||||
|
|
||||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
|
||||||
if not image_urls:
|
if not image_urls:
|
||||||
raise Exception("No image URLs were generated in the response")
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
|
||||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
|
||||||
|
|
||||||
|
|
||||||
class IdeogramV2(comfy_io.ComfyNode):
|
class IdeogramV2(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="IdeogramV2",
|
node_id="IdeogramV2",
|
||||||
display_name="Ideogram V2",
|
display_name="Ideogram V2",
|
||||||
category="api node/image/Ideogram",
|
category="api node/image/Ideogram",
|
||||||
description="Generates images using the Ideogram V2 model.",
|
description="Generates images using the Ideogram V2 model.",
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the image generation",
|
tooltip="Prompt for the image generation",
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"turbo",
|
"turbo",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
tooltip="Whether to use turbo mode (faster generation, potentially lower quality)",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=list(V1_V2_RATIO_MAP.keys()),
|
options=list(V1_V2_RATIO_MAP.keys()),
|
||||||
default="1:1",
|
default="1:1",
|
||||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=list(V1_V1_RES_MAP.keys()),
|
options=list(V1_V1_RES_MAP.keys()),
|
||||||
default="Auto",
|
default="Auto",
|
||||||
@ -412,44 +379,44 @@ class IdeogramV2(comfy_io.ComfyNode):
|
|||||||
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
"If not set to AUTO, this overrides the aspect_ratio setting.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"magic_prompt_option",
|
"magic_prompt_option",
|
||||||
options=["AUTO", "ON", "OFF"],
|
options=["AUTO", "ON", "OFF"],
|
||||||
default="AUTO",
|
default="AUTO",
|
||||||
tooltip="Determine if MagicPrompt should be used in generation",
|
tooltip="Determine if MagicPrompt should be used in generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"style_type",
|
"style_type",
|
||||||
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
options=["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
|
||||||
default="NONE",
|
default="NONE",
|
||||||
tooltip="Style type for generation (V2 only)",
|
tooltip="Style type for generation (V2 only)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Description of what to exclude from the image",
|
tooltip="Description of what to exclude from the image",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"num_images",
|
"num_images",
|
||||||
default=1,
|
default=1,
|
||||||
min=1,
|
min=1,
|
||||||
max=8,
|
max=8,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
#"color_palette": (
|
#"color_palette": (
|
||||||
@ -462,12 +429,12 @@ class IdeogramV2(comfy_io.ComfyNode):
|
|||||||
#),
|
#),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -500,18 +467,11 @@ class IdeogramV2(comfy_io.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
|
||||||
|
|
||||||
auth = {
|
response = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"),
|
||||||
}
|
response_model=IdeogramGenerateResponse,
|
||||||
operation = SynchronousOperation(
|
data=IdeogramGenerateRequest(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/ideogram/generate",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=IdeogramGenerateRequest,
|
|
||||||
response_model=IdeogramGenerateResponse,
|
|
||||||
),
|
|
||||||
request=IdeogramGenerateRequest(
|
|
||||||
image_request=ImageRequest(
|
image_request=ImageRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
@ -519,36 +479,28 @@ class IdeogramV2(comfy_io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
aspect_ratio=final_aspect_ratio,
|
aspect_ratio=final_aspect_ratio,
|
||||||
resolution=final_resolution,
|
resolution=final_resolution,
|
||||||
magic_prompt_option=(
|
magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None),
|
||||||
magic_prompt_option if magic_prompt_option != "AUTO" else None
|
|
||||||
),
|
|
||||||
style_type=style_type if style_type != "NONE" else None,
|
style_type=style_type if style_type != "NONE" else None,
|
||||||
negative_prompt=negative_prompt if negative_prompt else None,
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
color_palette=color_palette if color_palette else None,
|
color_palette=color_palette if color_palette else None,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
max_retries=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await operation.execute()
|
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
raise Exception("No images were generated in the response")
|
raise Exception("No images were generated in the response")
|
||||||
|
|
||||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
|
||||||
if not image_urls:
|
if not image_urls:
|
||||||
raise Exception("No image URLs were generated in the response")
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
|
||||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
|
||||||
|
|
||||||
|
|
||||||
class IdeogramV3(comfy_io.ComfyNode):
|
class IdeogramV3(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="IdeogramV3",
|
node_id="IdeogramV3",
|
||||||
display_name="Ideogram V3",
|
display_name="Ideogram V3",
|
||||||
category="api node/image/Ideogram",
|
category="api node/image/Ideogram",
|
||||||
@ -556,30 +508,30 @@ class IdeogramV3(comfy_io.ComfyNode):
|
|||||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the image generation or editing",
|
tooltip="Prompt for the image generation or editing",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Optional reference image for image editing.",
|
tooltip="Optional reference image for image editing.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Mask.Input(
|
IO.Mask.Input(
|
||||||
"mask",
|
"mask",
|
||||||
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
tooltip="Optional mask for inpainting (white areas will be replaced)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=list(V3_RATIO_MAP.keys()),
|
options=list(V3_RATIO_MAP.keys()),
|
||||||
default="1:1",
|
default="1:1",
|
||||||
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
tooltip="The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=V3_RESOLUTIONS,
|
options=V3_RESOLUTIONS,
|
||||||
default="Auto",
|
default="Auto",
|
||||||
@ -587,57 +539,57 @@ class IdeogramV3(comfy_io.ComfyNode):
|
|||||||
"If not set to Auto, this overrides the aspect_ratio setting.",
|
"If not set to Auto, this overrides the aspect_ratio setting.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"magic_prompt_option",
|
"magic_prompt_option",
|
||||||
options=["AUTO", "ON", "OFF"],
|
options=["AUTO", "ON", "OFF"],
|
||||||
default="AUTO",
|
default="AUTO",
|
||||||
tooltip="Determine if MagicPrompt should be used in generation",
|
tooltip="Determine if MagicPrompt should be used in generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"num_images",
|
"num_images",
|
||||||
default=1,
|
default=1,
|
||||||
min=1,
|
min=1,
|
||||||
max=8,
|
max=8,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"rendering_speed",
|
"rendering_speed",
|
||||||
options=["DEFAULT", "TURBO", "QUALITY"],
|
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||||
default="DEFAULT",
|
default="DEFAULT",
|
||||||
tooltip="Controls the trade-off between generation speed and quality",
|
tooltip="Controls the trade-off between generation speed and quality",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"character_image",
|
"character_image",
|
||||||
tooltip="Image to use as character reference.",
|
tooltip="Image to use as character reference.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Mask.Input(
|
IO.Mask.Input(
|
||||||
"character_mask",
|
"character_mask",
|
||||||
tooltip="Optional mask for character reference image.",
|
tooltip="Optional mask for character reference image.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -656,10 +608,6 @@ class IdeogramV3(comfy_io.ComfyNode):
|
|||||||
character_image=None,
|
character_image=None,
|
||||||
character_mask=None,
|
character_mask=None,
|
||||||
):
|
):
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
if rendering_speed == "BALANCED": # for backward compatibility
|
if rendering_speed == "BALANCED": # for backward compatibility
|
||||||
rendering_speed = "DEFAULT"
|
rendering_speed = "DEFAULT"
|
||||||
|
|
||||||
@ -694,9 +642,6 @@ class IdeogramV3(comfy_io.ComfyNode):
|
|||||||
|
|
||||||
# Check if both image and mask are provided for editing mode
|
# Check if both image and mask are provided for editing mode
|
||||||
if image is not None and mask is not None:
|
if image is not None and mask is not None:
|
||||||
# Edit mode
|
|
||||||
path = "/proxy/ideogram/ideogram-v3/edit"
|
|
||||||
|
|
||||||
# Process image and mask
|
# Process image and mask
|
||||||
input_tensor = image.squeeze().cpu()
|
input_tensor = image.squeeze().cpu()
|
||||||
# Resize mask to match image dimension
|
# Resize mask to match image dimension
|
||||||
@ -749,27 +694,20 @@ class IdeogramV3(comfy_io.ComfyNode):
|
|||||||
if character_mask_binary:
|
if character_mask_binary:
|
||||||
files["character_mask_binary"] = character_mask_binary
|
files["character_mask_binary"] = character_mask_binary
|
||||||
|
|
||||||
# Execute the operation for edit mode
|
response = await sync_op(
|
||||||
operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"),
|
||||||
path=path,
|
response_model=IdeogramGenerateResponse,
|
||||||
method=HttpMethod.POST,
|
data=edit_request,
|
||||||
request_model=IdeogramV3EditRequest,
|
|
||||||
response_model=IdeogramGenerateResponse,
|
|
||||||
),
|
|
||||||
request=edit_request,
|
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
max_retries=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif image is not None or mask is not None:
|
elif image is not None or mask is not None:
|
||||||
# If only one of image or mask is provided, raise an error
|
# If only one of image or mask is provided, raise an error
|
||||||
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
raise Exception("Ideogram V3 image editing requires both an image AND a mask")
|
||||||
else:
|
else:
|
||||||
# Generation mode
|
|
||||||
path = "/proxy/ideogram/ideogram-v3/generate"
|
|
||||||
|
|
||||||
# Create generation request
|
# Create generation request
|
||||||
gen_request = IdeogramV3Request(
|
gen_request = IdeogramV3Request(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -800,43 +738,34 @@ class IdeogramV3(comfy_io.ComfyNode):
|
|||||||
if files:
|
if files:
|
||||||
gen_request.style_type = "AUTO"
|
gen_request.style_type = "AUTO"
|
||||||
|
|
||||||
# Execute the operation for generation mode
|
response = await sync_op(
|
||||||
operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"),
|
||||||
path=path,
|
response_model=IdeogramGenerateResponse,
|
||||||
method=HttpMethod.POST,
|
data=gen_request,
|
||||||
request_model=IdeogramV3Request,
|
|
||||||
response_model=IdeogramGenerateResponse,
|
|
||||||
),
|
|
||||||
request=gen_request,
|
|
||||||
files=files if files else None,
|
files=files if files else None,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
max_retries=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the operation and process response
|
|
||||||
response = await operation.execute()
|
|
||||||
|
|
||||||
if not response.data or len(response.data) == 0:
|
if not response.data or len(response.data) == 0:
|
||||||
raise Exception("No images were generated in the response")
|
raise Exception("No images were generated in the response")
|
||||||
|
|
||||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||||
|
|
||||||
if not image_urls:
|
if not image_urls:
|
||||||
raise Exception("No image URLs were generated in the response")
|
raise Exception("No image URLs were generated in the response")
|
||||||
|
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||||
display_image_urls_on_node(image_urls, cls.hidden.unique_id)
|
|
||||||
return comfy_io.NodeOutput(await download_and_process_images(image_urls))
|
|
||||||
|
|
||||||
|
|
||||||
class IdeogramExtension(ComfyExtension):
|
class IdeogramExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
IdeogramV1,
|
IdeogramV1,
|
||||||
IdeogramV2,
|
IdeogramV2,
|
||||||
IdeogramV3,
|
IdeogramV3,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> IdeogramExtension:
|
async def comfy_entrypoint() -> IdeogramExtension:
|
||||||
return IdeogramExtension()
|
return IdeogramExtension()
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
199
comfy_api_nodes/nodes_ltxv.py
Normal file
199
comfy_api_nodes/nodes_ltxv.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
ApiEndpoint,
|
||||||
|
get_number_of_images,
|
||||||
|
sync_op_raw,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
|
MODELS_MAP = {
|
||||||
|
"LTX-2 (Pro)": "ltx-2-pro",
|
||||||
|
"LTX-2 (Fast)": "ltx-2-fast",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ExecuteTaskRequest(BaseModel):
|
||||||
|
prompt: str = Field(...)
|
||||||
|
model: str = Field(...)
|
||||||
|
duration: int = Field(...)
|
||||||
|
resolution: str = Field(...)
|
||||||
|
fps: Optional[int] = Field(25)
|
||||||
|
generate_audio: Optional[bool] = Field(True)
|
||||||
|
image_uri: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class TextToVideoNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="LtxvApiTextToVideo",
|
||||||
|
display_name="LTXV Text To Video",
|
||||||
|
category="api node/video/LTXV",
|
||||||
|
description="Professional-quality videos with customizable duration and resolution.",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
),
|
||||||
|
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=[
|
||||||
|
"1920x1080",
|
||||||
|
"2560x1440",
|
||||||
|
"3840x2160",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"generate_audio",
|
||||||
|
default=False,
|
||||||
|
optional=True,
|
||||||
|
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
duration: int,
|
||||||
|
resolution: str,
|
||||||
|
fps: int = 25,
|
||||||
|
generate_audio: bool = False,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, min_length=1, max_length=10000)
|
||||||
|
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||||
|
raise ValueError(
|
||||||
|
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||||
|
)
|
||||||
|
response = await sync_op_raw(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
|
||||||
|
data=ExecuteTaskRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
model=MODELS_MAP[model],
|
||||||
|
duration=duration,
|
||||||
|
resolution=resolution,
|
||||||
|
fps=fps,
|
||||||
|
generate_audio=generate_audio,
|
||||||
|
),
|
||||||
|
as_binary=True,
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||||
|
|
||||||
|
|
||||||
|
class ImageToVideoNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="LtxvApiImageToVideo",
|
||||||
|
display_name="LTXV Image To Video",
|
||||||
|
category="api node/video/LTXV",
|
||||||
|
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||||
|
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
),
|
||||||
|
IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=[
|
||||||
|
"1920x1080",
|
||||||
|
"2560x1440",
|
||||||
|
"3840x2160",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.Combo.Input("fps", options=[25, 50], default=25),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"generate_audio",
|
||||||
|
default=False,
|
||||||
|
optional=True,
|
||||||
|
tooltip="When true, the generated video will include AI-generated audio matching the scene.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
image: torch.Tensor,
|
||||||
|
model: str,
|
||||||
|
prompt: str,
|
||||||
|
duration: int,
|
||||||
|
resolution: str,
|
||||||
|
fps: int = 25,
|
||||||
|
generate_audio: bool = False,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, min_length=1, max_length=10000)
|
||||||
|
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
|
||||||
|
raise ValueError(
|
||||||
|
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
|
||||||
|
)
|
||||||
|
if get_number_of_images(image) != 1:
|
||||||
|
raise ValueError("Currently only one input image is supported.")
|
||||||
|
response = await sync_op_raw(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"),
|
||||||
|
data=ExecuteTaskRequest(
|
||||||
|
image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0],
|
||||||
|
prompt=prompt,
|
||||||
|
model=MODELS_MAP[model],
|
||||||
|
duration=duration,
|
||||||
|
resolution=resolution,
|
||||||
|
fps=fps,
|
||||||
|
generate_audio=generate_audio,
|
||||||
|
),
|
||||||
|
as_binary=True,
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||||
|
|
||||||
|
|
||||||
|
class LtxvApiExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
TextToVideoNode,
|
||||||
|
ImageToVideoNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LtxvApiExtension:
|
||||||
|
return LtxvApiExtension()
|
||||||
@ -1,75 +1,57 @@
|
|||||||
from __future__ import annotations
|
|
||||||
from inspect import cleandoc
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.apis.luma_api import (
|
from comfy_api_nodes.apis.luma_api import (
|
||||||
LumaImageModel,
|
|
||||||
LumaVideoModel,
|
|
||||||
LumaVideoOutputResolution,
|
|
||||||
LumaVideoModelOutputDuration,
|
|
||||||
LumaAspectRatio,
|
LumaAspectRatio,
|
||||||
LumaState,
|
|
||||||
LumaImageGenerationRequest,
|
|
||||||
LumaGenerationRequest,
|
|
||||||
LumaGeneration,
|
|
||||||
LumaCharacterRef,
|
LumaCharacterRef,
|
||||||
LumaModifyImageRef,
|
LumaConceptChain,
|
||||||
|
LumaGeneration,
|
||||||
|
LumaGenerationRequest,
|
||||||
|
LumaImageGenerationRequest,
|
||||||
LumaImageIdentity,
|
LumaImageIdentity,
|
||||||
|
LumaImageModel,
|
||||||
|
LumaImageReference,
|
||||||
|
LumaIO,
|
||||||
|
LumaKeyframes,
|
||||||
|
LumaModifyImageRef,
|
||||||
LumaReference,
|
LumaReference,
|
||||||
LumaReferenceChain,
|
LumaReferenceChain,
|
||||||
LumaImageReference,
|
LumaVideoModel,
|
||||||
LumaKeyframes,
|
LumaVideoModelOutputDuration,
|
||||||
LumaConceptChain,
|
LumaVideoOutputResolution,
|
||||||
LumaIO,
|
|
||||||
get_luma_concepts,
|
get_luma_concepts,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_image_tensor,
|
||||||
SynchronousOperation,
|
download_url_to_video_output,
|
||||||
PollingOperation,
|
poll_op,
|
||||||
EmptyRequest,
|
sync_op,
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
process_image_response,
|
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import torch
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
LUMA_T2V_AVERAGE_DURATION = 105
|
LUMA_T2V_AVERAGE_DURATION = 105
|
||||||
LUMA_I2V_AVERAGE_DURATION = 100
|
LUMA_I2V_AVERAGE_DURATION = 100
|
||||||
|
|
||||||
def image_result_url_extractor(response: LumaGeneration):
|
|
||||||
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
|
|
||||||
|
|
||||||
def video_result_url_extractor(response: LumaGeneration):
|
|
||||||
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
|
|
||||||
|
|
||||||
class LumaReferenceNode(comfy_io.ComfyNode):
|
|
||||||
"""
|
|
||||||
Holds an image and weight for use with Luma Generate Image node.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
class LumaReferenceNode(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="LumaReferenceNode",
|
node_id="LumaReferenceNode",
|
||||||
display_name="Luma Reference",
|
display_name="Luma Reference",
|
||||||
category="api node/image/Luma",
|
category="api node/image/Luma",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Image to use as reference.",
|
tooltip="Image to use as reference.",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"weight",
|
"weight",
|
||||||
default=1.0,
|
default=1.0,
|
||||||
min=0.0,
|
min=0.0,
|
||||||
@ -77,72 +59,56 @@ class LumaReferenceNode(comfy_io.ComfyNode):
|
|||||||
step=0.01,
|
step=0.01,
|
||||||
tooltip="Weight of image reference.",
|
tooltip="Weight of image reference.",
|
||||||
),
|
),
|
||||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
IO.Custom(LumaIO.LUMA_REF).Input(
|
||||||
"luma_ref",
|
"luma_ref",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")],
|
||||||
hidden=[
|
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
|
||||||
comfy_io.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(
|
def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput:
|
||||||
cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
|
|
||||||
) -> comfy_io.NodeOutput:
|
|
||||||
if luma_ref is not None:
|
if luma_ref is not None:
|
||||||
luma_ref = luma_ref.clone()
|
luma_ref = luma_ref.clone()
|
||||||
else:
|
else:
|
||||||
luma_ref = LumaReferenceChain()
|
luma_ref = LumaReferenceChain()
|
||||||
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
|
||||||
return comfy_io.NodeOutput(luma_ref)
|
return IO.NodeOutput(luma_ref)
|
||||||
|
|
||||||
|
|
||||||
class LumaConceptsNode(comfy_io.ComfyNode):
|
class LumaConceptsNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="LumaConceptsNode",
|
node_id="LumaConceptsNode",
|
||||||
display_name="Luma Concepts",
|
display_name="Luma Concepts",
|
||||||
category="api node/video/Luma",
|
category="api node/video/Luma",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"concept1",
|
"concept1",
|
||||||
options=get_luma_concepts(include_none=True),
|
options=get_luma_concepts(include_none=True),
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"concept2",
|
"concept2",
|
||||||
options=get_luma_concepts(include_none=True),
|
options=get_luma_concepts(include_none=True),
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"concept3",
|
"concept3",
|
||||||
options=get_luma_concepts(include_none=True),
|
options=get_luma_concepts(include_none=True),
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"concept4",
|
"concept4",
|
||||||
options=get_luma_concepts(include_none=True),
|
options=get_luma_concepts(include_none=True),
|
||||||
),
|
),
|
||||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||||
"luma_concepts",
|
"luma_concepts",
|
||||||
tooltip="Optional Camera Concepts to add to the ones chosen here.",
|
tooltip="Optional Camera Concepts to add to the ones chosen here.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")],
|
||||||
hidden=[
|
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
|
||||||
comfy_io.Hidden.unique_id,
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -153,42 +119,38 @@ class LumaConceptsNode(comfy_io.ComfyNode):
|
|||||||
concept3: str,
|
concept3: str,
|
||||||
concept4: str,
|
concept4: str,
|
||||||
luma_concepts: LumaConceptChain = None,
|
luma_concepts: LumaConceptChain = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
|
||||||
if luma_concepts is not None:
|
if luma_concepts is not None:
|
||||||
chain = luma_concepts.clone_and_merge(chain)
|
chain = luma_concepts.clone_and_merge(chain)
|
||||||
return comfy_io.NodeOutput(chain)
|
return IO.NodeOutput(chain)
|
||||||
|
|
||||||
|
|
||||||
class LumaImageGenerationNode(comfy_io.ComfyNode):
|
class LumaImageGenerationNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates images synchronously based on prompt and aspect ratio.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="LumaImageNode",
|
node_id="LumaImageNode",
|
||||||
display_name="Luma Text to Image",
|
display_name="Luma Text to Image",
|
||||||
category="api node/image/Luma",
|
category="api node/image/Luma",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the image generation",
|
tooltip="Prompt for the image generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in LumaImageModel],
|
options=LumaImageModel,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[ratio.value for ratio in LumaAspectRatio],
|
options=LumaAspectRatio,
|
||||||
default=LumaAspectRatio.ratio_16_9,
|
default=LumaAspectRatio.ratio_16_9,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -196,7 +158,7 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
|||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"style_image_weight",
|
"style_image_weight",
|
||||||
default=1.0,
|
default=1.0,
|
||||||
min=0.0,
|
min=0.0,
|
||||||
@ -204,27 +166,27 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
|||||||
step=0.01,
|
step=0.01,
|
||||||
tooltip="Weight of style image. Ignored if no style_image provided.",
|
tooltip="Weight of style image. Ignored if no style_image provided.",
|
||||||
),
|
),
|
||||||
comfy_io.Custom(LumaIO.LUMA_REF).Input(
|
IO.Custom(LumaIO.LUMA_REF).Input(
|
||||||
"image_luma_ref",
|
"image_luma_ref",
|
||||||
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
|
tooltip="Luma Reference node connection to influence generation with input images; up to 4 images can be considered.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"style_image",
|
"style_image",
|
||||||
tooltip="Style reference image; only 1 image will be used.",
|
tooltip="Style reference image; only 1 image will be used.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"character_image",
|
"character_image",
|
||||||
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
|
tooltip="Character reference images; can be a batch of multiple, up to 4 images can be considered.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Image.Output()],
|
outputs=[IO.Image.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -237,45 +199,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
seed,
|
seed,
|
||||||
style_image_weight: float,
|
style_image_weight: float,
|
||||||
image_luma_ref: LumaReferenceChain = None,
|
image_luma_ref: Optional[LumaReferenceChain] = None,
|
||||||
style_image: torch.Tensor = None,
|
style_image: Optional[torch.Tensor] = None,
|
||||||
character_image: torch.Tensor = None,
|
character_image: Optional[torch.Tensor] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=3)
|
validate_string(prompt, strip_whitespace=True, min_length=3)
|
||||||
auth_kwargs = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
# handle image_luma_ref
|
# handle image_luma_ref
|
||||||
api_image_ref = None
|
api_image_ref = None
|
||||||
if image_luma_ref is not None:
|
if image_luma_ref is not None:
|
||||||
api_image_ref = await cls._convert_luma_refs(
|
api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4)
|
||||||
image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
# handle style_luma_ref
|
# handle style_luma_ref
|
||||||
api_style_ref = None
|
api_style_ref = None
|
||||||
if style_image is not None:
|
if style_image is not None:
|
||||||
api_style_ref = await cls._convert_style_image(
|
api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight)
|
||||||
style_image, weight=style_image_weight, auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
# handle character_ref images
|
# handle character_ref images
|
||||||
character_ref = None
|
character_ref = None
|
||||||
if character_image is not None:
|
if character_image is not None:
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4)
|
||||||
character_image, max_images=4, auth_kwargs=auth_kwargs,
|
character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls))
|
||||||
)
|
|
||||||
character_ref = LumaCharacterRef(
|
|
||||||
identity0=LumaImageIdentity(images=download_urls)
|
|
||||||
)
|
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/luma/generations/image",
|
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=LumaGeneration,
|
||||||
request_model=LumaImageGenerationRequest,
|
data=LumaImageGenerationRequest(
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
request=LumaImageGenerationRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
@ -283,41 +230,21 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
|||||||
style_ref=api_style_ref,
|
style_ref=api_style_ref,
|
||||||
character_ref=character_ref,
|
character_ref=character_ref,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = await operation.execute()
|
response_poll = await poll_op(
|
||||||
|
cls,
|
||||||
operation = PollingOperation(
|
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||||
poll_endpoint=ApiEndpoint(
|
response_model=LumaGeneration,
|
||||||
path=f"/proxy/luma/generations/{response_api.id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
completed_statuses=[LumaState.completed],
|
|
||||||
failed_statuses=[LumaState.failed],
|
|
||||||
status_extractor=lambda x: x.state,
|
status_extractor=lambda x: x.state,
|
||||||
result_url_extractor=image_result_url_extractor,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.assets.image) as img_response:
|
|
||||||
img = process_image_response(await img_response.content.read())
|
|
||||||
return comfy_io.NodeOutput(img)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _convert_luma_refs(
|
async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int):
|
||||||
cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
|
|
||||||
):
|
|
||||||
luma_urls = []
|
luma_urls = []
|
||||||
ref_count = 0
|
ref_count = 0
|
||||||
for ref in luma_ref.refs:
|
for ref in luma_ref.refs:
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1)
|
||||||
ref.image, max_images=1, auth_kwargs=auth_kwargs
|
|
||||||
)
|
|
||||||
luma_urls.append(download_urls[0])
|
luma_urls.append(download_urls[0])
|
||||||
ref_count += 1
|
ref_count += 1
|
||||||
if ref_count >= max_refs:
|
if ref_count >= max_refs:
|
||||||
@ -325,38 +252,30 @@ class LumaImageGenerationNode(comfy_io.ComfyNode):
|
|||||||
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _convert_style_image(
|
async def _convert_style_image(cls, style_image: torch.Tensor, weight: float):
|
||||||
cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
|
chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight))
|
||||||
):
|
return await cls._convert_luma_refs(chain, max_refs=1)
|
||||||
chain = LumaReferenceChain(
|
|
||||||
first_ref=LumaReference(image=style_image, weight=weight)
|
|
||||||
)
|
|
||||||
return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class LumaImageModifyNode(comfy_io.ComfyNode):
|
class LumaImageModifyNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Modifies images synchronously based on prompt and aspect ratio.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="LumaImageModifyNode",
|
node_id="LumaImageModifyNode",
|
||||||
display_name="Luma Image to Image",
|
display_name="Luma Image to Image",
|
||||||
category="api node/image/Luma",
|
category="api node/image/Luma",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the image generation",
|
tooltip="Prompt for the image generation",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"image_weight",
|
"image_weight",
|
||||||
default=0.1,
|
default=0.1,
|
||||||
min=0.0,
|
min=0.0,
|
||||||
@ -364,11 +283,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
|||||||
step=0.01,
|
step=0.01,
|
||||||
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
|
tooltip="Weight of the image; the closer to 1.0, the less the image will be modified.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in LumaImageModel],
|
options=LumaImageModel,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -377,11 +296,11 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
|||||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Image.Output()],
|
outputs=[IO.Image.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -394,99 +313,68 @@ class LumaImageModifyNode(comfy_io.ComfyNode):
|
|||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
image_weight: float,
|
image_weight: float,
|
||||||
seed,
|
seed,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
auth_kwargs = {
|
download_urls = await upload_images_to_comfyapi(cls, image, max_images=1)
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
# first, upload image
|
|
||||||
download_urls = await upload_images_to_comfyapi(
|
|
||||||
image, max_images=1, auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
image_url = download_urls[0]
|
image_url = download_urls[0]
|
||||||
# next, make Luma call with download url provided
|
response_api = await sync_op(
|
||||||
operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/luma/generations/image", method="POST"),
|
||||||
path="/proxy/luma/generations/image",
|
response_model=LumaGeneration,
|
||||||
method=HttpMethod.POST,
|
data=LumaImageGenerationRequest(
|
||||||
request_model=LumaImageGenerationRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
request=LumaImageGenerationRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
modify_image_ref=LumaModifyImageRef(
|
modify_image_ref=LumaModifyImageRef(
|
||||||
url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
|
url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2)
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = await operation.execute()
|
response_poll = await poll_op(
|
||||||
|
cls,
|
||||||
operation = PollingOperation(
|
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||||
poll_endpoint=ApiEndpoint(
|
response_model=LumaGeneration,
|
||||||
path=f"/proxy/luma/generations/{response_api.id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
completed_statuses=[LumaState.completed],
|
|
||||||
failed_statuses=[LumaState.failed],
|
|
||||||
status_extractor=lambda x: x.state,
|
status_extractor=lambda x: x.state,
|
||||||
result_url_extractor=image_result_url_extractor,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.assets.image) as img_response:
|
|
||||||
img = process_image_response(await img_response.content.read())
|
|
||||||
return comfy_io.NodeOutput(img)
|
|
||||||
|
|
||||||
|
|
||||||
class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on prompt and output_size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="LumaVideoNode",
|
node_id="LumaVideoNode",
|
||||||
display_name="Luma Text to Video",
|
display_name="Luma Text to Video",
|
||||||
category="api node/video/Luma",
|
category="api node/video/Luma",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on prompt and output_size.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the video generation",
|
tooltip="Prompt for the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in LumaVideoModel],
|
options=LumaVideoModel,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[ratio.value for ratio in LumaAspectRatio],
|
options=LumaAspectRatio,
|
||||||
default=LumaAspectRatio.ratio_16_9,
|
default=LumaAspectRatio.ratio_16_9,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[resolution.value for resolution in LumaVideoOutputResolution],
|
options=LumaVideoOutputResolution,
|
||||||
default=LumaVideoOutputResolution.res_540p,
|
default=LumaVideoOutputResolution.res_540p,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
options=LumaVideoModelOutputDuration,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"loop",
|
"loop",
|
||||||
default=False,
|
default=False,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -494,17 +382,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||||
"luma_concepts",
|
"luma_concepts",
|
||||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||||
optional=True,
|
optional=True,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -519,24 +407,17 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
duration: str,
|
duration: str,
|
||||||
loop: bool,
|
loop: bool,
|
||||||
seed,
|
seed,
|
||||||
luma_concepts: LumaConceptChain = None,
|
luma_concepts: Optional[LumaConceptChain] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False, min_length=3)
|
validate_string(prompt, strip_whitespace=False, min_length=3)
|
||||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
|
||||||
auth_kwargs = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||||
}
|
response_model=LumaGeneration,
|
||||||
operation = SynchronousOperation(
|
data=LumaGenerationRequest(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/luma/generations",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=LumaGenerationRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
request=LumaGenerationRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
@ -545,77 +426,55 @@ class LumaTextToVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
loop=loop,
|
loop=loop,
|
||||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = await operation.execute()
|
response_poll = await poll_op(
|
||||||
|
cls,
|
||||||
if cls.hidden.unique_id:
|
ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
response_model=LumaGeneration,
|
||||||
|
|
||||||
operation = PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path=f"/proxy/luma/generations/{response_api.id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
completed_statuses=[LumaState.completed],
|
|
||||||
failed_statuses=[LumaState.failed],
|
|
||||||
status_extractor=lambda x: x.state,
|
status_extractor=lambda x: x.state,
|
||||||
result_url_extractor=video_result_url_extractor,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.assets.video) as vid_response:
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
|
|
||||||
class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on prompt, input images, and output_size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="LumaImageToVideoNode",
|
node_id="LumaImageToVideoNode",
|
||||||
display_name="Luma Image to Video",
|
display_name="Luma Image to Video",
|
||||||
category="api node/video/Luma",
|
category="api node/video/Luma",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the video generation",
|
tooltip="Prompt for the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in LumaVideoModel],
|
options=LumaVideoModel,
|
||||||
),
|
),
|
||||||
# comfy_io.Combo.Input(
|
# IO.Combo.Input(
|
||||||
# "aspect_ratio",
|
# "aspect_ratio",
|
||||||
# options=[ratio.value for ratio in LumaAspectRatio],
|
# options=[ratio.value for ratio in LumaAspectRatio],
|
||||||
# default=LumaAspectRatio.ratio_16_9,
|
# default=LumaAspectRatio.ratio_16_9,
|
||||||
# ),
|
# ),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[resolution.value for resolution in LumaVideoOutputResolution],
|
options=LumaVideoOutputResolution,
|
||||||
default=LumaVideoOutputResolution.res_540p,
|
default=LumaVideoOutputResolution.res_540p,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
options=[dur.value for dur in LumaVideoModelOutputDuration],
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"loop",
|
"loop",
|
||||||
default=False,
|
default=False,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -623,27 +482,27 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
tooltip="Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"first_image",
|
"first_image",
|
||||||
tooltip="First frame of generated video.",
|
tooltip="First frame of generated video.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"last_image",
|
"last_image",
|
||||||
tooltip="Last frame of generated video.",
|
tooltip="Last frame of generated video.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
IO.Custom(LumaIO.LUMA_CONCEPTS).Input(
|
||||||
"luma_concepts",
|
"luma_concepts",
|
||||||
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.",
|
||||||
optional=True,
|
optional=True,
|
||||||
)
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -660,27 +519,17 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
first_image: torch.Tensor = None,
|
first_image: torch.Tensor = None,
|
||||||
last_image: torch.Tensor = None,
|
last_image: torch.Tensor = None,
|
||||||
luma_concepts: LumaConceptChain = None,
|
luma_concepts: LumaConceptChain = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if first_image is None and last_image is None:
|
if first_image is None and last_image is None:
|
||||||
raise Exception(
|
raise Exception("At least one of first_image and last_image requires an input.")
|
||||||
"At least one of first_image and last_image requires an input."
|
keyframes = await cls._convert_to_keyframes(first_image, last_image)
|
||||||
)
|
|
||||||
auth_kwargs = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs)
|
|
||||||
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
duration = duration if model != LumaVideoModel.ray_1_6 else None
|
||||||
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
resolution = resolution if model != LumaVideoModel.ray_1_6 else None
|
||||||
|
response_api = await sync_op(
|
||||||
operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/luma/generations", method="POST"),
|
||||||
path="/proxy/luma/generations",
|
response_model=LumaGeneration,
|
||||||
method=HttpMethod.POST,
|
data=LumaGenerationRequest(
|
||||||
request_model=LumaGenerationRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
request=LumaGenerationRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
model=model,
|
model=model,
|
||||||
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
|
||||||
@ -690,61 +539,38 @@ class LumaImageToVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
keyframes=keyframes,
|
keyframes=keyframes,
|
||||||
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
concepts=luma_concepts.create_api_model() if luma_concepts else None,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_api: LumaGeneration = await operation.execute()
|
response_poll = await poll_op(
|
||||||
|
cls,
|
||||||
if cls.hidden.unique_id:
|
poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"),
|
||||||
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id)
|
response_model=LumaGeneration,
|
||||||
|
|
||||||
operation = PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path=f"/proxy/luma/generations/{response_api.id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=LumaGeneration,
|
|
||||||
),
|
|
||||||
completed_statuses=[LumaState.completed],
|
|
||||||
failed_statuses=[LumaState.failed],
|
|
||||||
status_extractor=lambda x: x.state,
|
status_extractor=lambda x: x.state,
|
||||||
result_url_extractor=video_result_url_extractor,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.assets.video) as vid_response:
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _convert_to_keyframes(
|
async def _convert_to_keyframes(
|
||||||
cls,
|
cls,
|
||||||
first_image: torch.Tensor = None,
|
first_image: torch.Tensor = None,
|
||||||
last_image: torch.Tensor = None,
|
last_image: torch.Tensor = None,
|
||||||
auth_kwargs: Optional[dict[str,str]] = None,
|
|
||||||
):
|
):
|
||||||
if first_image is None and last_image is None:
|
if first_image is None and last_image is None:
|
||||||
return None
|
return None
|
||||||
frame0 = None
|
frame0 = None
|
||||||
frame1 = None
|
frame1 = None
|
||||||
if first_image is not None:
|
if first_image is not None:
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1)
|
||||||
first_image, max_images=1, auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
frame0 = LumaImageReference(type="image", url=download_urls[0])
|
||||||
if last_image is not None:
|
if last_image is not None:
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1)
|
||||||
last_image, max_images=1, auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
frame1 = LumaImageReference(type="image", url=download_urls[0])
|
||||||
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
return LumaKeyframes(frame0=frame0, frame1=frame1)
|
||||||
|
|
||||||
|
|
||||||
class LumaExtension(ComfyExtension):
|
class LumaExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
LumaImageGenerationNode,
|
LumaImageGenerationNode,
|
||||||
LumaImageModifyNode,
|
LumaImageModifyNode,
|
||||||
|
|||||||
@ -1,71 +1,57 @@
|
|||||||
from inspect import cleandoc
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import logging
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api_nodes.apis.minimax_api import (
|
||||||
|
MinimaxFileRetrieveResponse,
|
||||||
|
MiniMaxModel,
|
||||||
|
MinimaxTaskResultResponse,
|
||||||
MinimaxVideoGenerationRequest,
|
MinimaxVideoGenerationRequest,
|
||||||
MinimaxVideoGenerationResponse,
|
MinimaxVideoGenerationResponse,
|
||||||
MinimaxFileRetrieveResponse,
|
|
||||||
MinimaxTaskResultResponse,
|
|
||||||
SubjectReferenceItem,
|
SubjectReferenceItem,
|
||||||
MiniMaxModel,
|
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_video_output,
|
||||||
SynchronousOperation,
|
poll_op,
|
||||||
PollingOperation,
|
sync_op,
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
download_url_to_bytesio,
|
|
||||||
upload_images_to_comfyapi,
|
upload_images_to_comfyapi,
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
from server import PromptServer
|
|
||||||
|
|
||||||
|
|
||||||
I2V_AVERAGE_DURATION = 114
|
I2V_AVERAGE_DURATION = 114
|
||||||
T2V_AVERAGE_DURATION = 234
|
T2V_AVERAGE_DURATION = 234
|
||||||
|
|
||||||
|
|
||||||
async def _generate_mm_video(
|
async def _generate_mm_video(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
*,
|
*,
|
||||||
auth: dict[str, str],
|
|
||||||
node_id: str,
|
|
||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
model: str,
|
model: str,
|
||||||
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
image: Optional[torch.Tensor] = None, # used for ImageToVideo
|
||||||
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
subject: Optional[torch.Tensor] = None, # used for SubjectToVideo
|
||||||
average_duration: Optional[int] = None,
|
average_duration: Optional[int] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if image is None:
|
if image is None:
|
||||||
validate_string(prompt_text, field_name="prompt_text")
|
validate_string(prompt_text, field_name="prompt_text")
|
||||||
# upload image, if passed in
|
|
||||||
image_url = None
|
image_url = None
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0]
|
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0]
|
||||||
|
|
||||||
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
# TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
|
||||||
subject_reference = None
|
subject_reference = None
|
||||||
if subject is not None:
|
if subject is not None:
|
||||||
subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0]
|
subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0]
|
||||||
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
subject_reference = [SubjectReferenceItem(image=subject_url)]
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
video_generate_operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||||
path="/proxy/minimax/video_generation",
|
response_model=MinimaxVideoGenerationResponse,
|
||||||
method=HttpMethod.POST,
|
data=MinimaxVideoGenerationRequest(
|
||||||
request_model=MinimaxVideoGenerationRequest,
|
|
||||||
response_model=MinimaxVideoGenerationResponse,
|
|
||||||
),
|
|
||||||
request=MinimaxVideoGenerationRequest(
|
|
||||||
model=MiniMaxModel(model),
|
model=MiniMaxModel(model),
|
||||||
prompt=prompt_text,
|
prompt=prompt_text,
|
||||||
callback_url=None,
|
callback_url=None,
|
||||||
@ -73,95 +59,64 @@ async def _generate_mm_video(
|
|||||||
subject_reference=subject_reference,
|
subject_reference=subject_reference,
|
||||||
prompt_optimizer=None,
|
prompt_optimizer=None,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response = await video_generate_operation.execute()
|
|
||||||
|
|
||||||
task_id = response.task_id
|
task_id = response.task_id
|
||||||
if not task_id:
|
if not task_id:
|
||||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||||
|
|
||||||
video_generate_operation = PollingOperation(
|
task_result = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/query/video_generation",
|
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||||
method=HttpMethod.GET,
|
response_model=MinimaxTaskResultResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MinimaxTaskResultResponse,
|
|
||||||
query_params={"task_id": task_id},
|
|
||||||
),
|
|
||||||
completed_statuses=["Success"],
|
|
||||||
failed_statuses=["Fail"],
|
|
||||||
status_extractor=lambda x: x.status.value,
|
status_extractor=lambda x: x.status.value,
|
||||||
estimated_duration=average_duration,
|
estimated_duration=average_duration,
|
||||||
node_id=node_id,
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
task_result = await video_generate_operation.execute()
|
|
||||||
|
|
||||||
file_id = task_result.file_id
|
file_id = task_result.file_id
|
||||||
if file_id is None:
|
if file_id is None:
|
||||||
raise Exception("Request was not successful. Missing file ID.")
|
raise Exception("Request was not successful. Missing file ID.")
|
||||||
file_retrieve_operation = SynchronousOperation(
|
file_result = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/files/retrieve",
|
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||||
method=HttpMethod.GET,
|
response_model=MinimaxFileRetrieveResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MinimaxFileRetrieveResponse,
|
|
||||||
query_params={"file_id": int(file_id)},
|
|
||||||
),
|
|
||||||
request=EmptyRequest(),
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
file_result = await file_retrieve_operation.execute()
|
|
||||||
|
|
||||||
file_url = file_result.file.download_url
|
file_url = file_result.file.download_url
|
||||||
if file_url is None:
|
if file_url is None:
|
||||||
raise Exception(
|
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
if file_result.file.backup_download_url:
|
||||||
)
|
try:
|
||||||
logging.info("Generated video URL: %s", file_url)
|
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||||
if node_id:
|
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||||
if hasattr(file_result.file, "backup_download_url"):
|
return IO.NodeOutput(
|
||||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||||
else:
|
)
|
||||||
message = f"Result URL: {file_url}"
|
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||||
PromptServer.instance.send_progress_text(message, node_id)
|
|
||||||
|
|
||||||
# Download and return as VideoFromFile
|
|
||||||
video_io = await download_url_to_bytesio(file_url)
|
|
||||||
if video_io is None:
|
|
||||||
error_msg = f"Failed to download video from {file_url}"
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise Exception(error_msg)
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
|
||||||
|
|
||||||
|
|
||||||
class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MinimaxTextToVideoNode",
|
node_id="MinimaxTextToVideoNode",
|
||||||
display_name="MiniMax Text to Video",
|
display_name="MiniMax Text to Video",
|
||||||
category="api node/video/MiniMax",
|
category="api node/video/MiniMax",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt_text",
|
"prompt_text",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt to guide the video generation",
|
tooltip="Text prompt to guide the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["T2V-01", "T2V-01-Director"],
|
options=["T2V-01", "T2V-01-Director"],
|
||||||
default="T2V-01",
|
default="T2V-01",
|
||||||
tooltip="Model to use for video generation",
|
tooltip="Model to use for video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -172,11 +127,11 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -187,13 +142,9 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
model: str = "T2V-01",
|
model: str = "T2V-01",
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
return await _generate_mm_video(
|
return await _generate_mm_video(
|
||||||
auth={
|
cls,
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
model=model,
|
model=model,
|
||||||
@ -203,36 +154,32 @@ class MinimaxTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MinimaxImageToVideoNode",
|
node_id="MinimaxImageToVideoNode",
|
||||||
display_name="MiniMax Image to Video",
|
display_name="MiniMax Image to Video",
|
||||||
category="api node/video/MiniMax",
|
category="api node/video/MiniMax",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Image to use as first frame of video generation",
|
tooltip="Image to use as first frame of video generation",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt_text",
|
"prompt_text",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt to guide the video generation",
|
tooltip="Text prompt to guide the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
|
options=["I2V-01-Director", "I2V-01", "I2V-01-live"],
|
||||||
default="I2V-01",
|
default="I2V-01",
|
||||||
tooltip="Model to use for video generation",
|
tooltip="Model to use for video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -243,11 +190,11 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -259,13 +206,9 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
model: str = "I2V-01",
|
model: str = "I2V-01",
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
return await _generate_mm_video(
|
return await _generate_mm_video(
|
||||||
auth={
|
cls,
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
model=model,
|
model=model,
|
||||||
@ -275,36 +218,32 @@ class MinimaxImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MinimaxSubjectToVideoNode",
|
node_id="MinimaxSubjectToVideoNode",
|
||||||
display_name="MiniMax Subject to Video",
|
display_name="MiniMax Subject to Video",
|
||||||
category="api node/video/MiniMax",
|
category="api node/video/MiniMax",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"subject",
|
"subject",
|
||||||
tooltip="Image of subject to reference for video generation",
|
tooltip="Image of subject to reference for video generation",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt_text",
|
"prompt_text",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt to guide the video generation",
|
tooltip="Text prompt to guide the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["S2V-01"],
|
options=["S2V-01"],
|
||||||
default="S2V-01",
|
default="S2V-01",
|
||||||
tooltip="Model to use for video generation",
|
tooltip="Model to use for video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -315,11 +254,11 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -331,13 +270,9 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
|||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
model: str = "S2V-01",
|
model: str = "S2V-01",
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
return await _generate_mm_video(
|
return await _generate_mm_video(
|
||||||
auth={
|
cls,
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
prompt_text=prompt_text,
|
prompt_text=prompt_text,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
model=model,
|
model=model,
|
||||||
@ -347,24 +282,22 @@ class MinimaxSubjectToVideoNode(comfy_io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||||
"""Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model."""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MinimaxHailuoVideoNode",
|
node_id="MinimaxHailuoVideoNode",
|
||||||
display_name="MiniMax Hailuo Video",
|
display_name="MiniMax Hailuo Video",
|
||||||
category="api node/video/MiniMax",
|
category="api node/video/MiniMax",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt_text",
|
"prompt_text",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt to guide the video generation.",
|
tooltip="Text prompt to guide the video generation.",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -374,25 +307,25 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
|||||||
tooltip="The random seed used for creating the noise.",
|
tooltip="The random seed used for creating the noise.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"first_frame_image",
|
"first_frame_image",
|
||||||
tooltip="Optional image to use as the first frame to generate a video.",
|
tooltip="Optional image to use as the first frame to generate a video.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_optimizer",
|
"prompt_optimizer",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Optimize prompt to improve generation quality when needed.",
|
tooltip="Optimize prompt to improve generation quality when needed.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[6, 10],
|
options=[6, 10],
|
||||||
default=6,
|
default=6,
|
||||||
tooltip="The length of the output video in seconds.",
|
tooltip="The length of the output video in seconds.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=["768P", "1080P"],
|
options=["768P", "1080P"],
|
||||||
default="768P",
|
default="768P",
|
||||||
@ -400,11 +333,11 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -419,11 +352,7 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
|||||||
duration: int = 6,
|
duration: int = 6,
|
||||||
resolution: str = "768P",
|
resolution: str = "768P",
|
||||||
model: str = "MiniMax-Hailuo-02",
|
model: str = "MiniMax-Hailuo-02",
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
if first_frame_image is None:
|
if first_frame_image is None:
|
||||||
validate_string(prompt_text, field_name="prompt_text")
|
validate_string(prompt_text, field_name="prompt_text")
|
||||||
|
|
||||||
@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
|||||||
# upload image, if passed in
|
# upload image, if passed in
|
||||||
image_url = None
|
image_url = None
|
||||||
if first_frame_image is not None:
|
if first_frame_image is not None:
|
||||||
image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0]
|
image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0]
|
||||||
|
|
||||||
video_generate_operation = SynchronousOperation(
|
response = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/video_generation",
|
ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=MinimaxVideoGenerationResponse,
|
||||||
request_model=MinimaxVideoGenerationRequest,
|
data=MinimaxVideoGenerationRequest(
|
||||||
response_model=MinimaxVideoGenerationResponse,
|
|
||||||
),
|
|
||||||
request=MinimaxVideoGenerationRequest(
|
|
||||||
model=MiniMaxModel(model),
|
model=MiniMaxModel(model),
|
||||||
prompt=prompt_text,
|
prompt=prompt_text,
|
||||||
callback_url=None,
|
callback_url=None,
|
||||||
@ -453,72 +379,47 @@ class MinimaxHailuoVideoNode(comfy_io.ComfyNode):
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response = await video_generate_operation.execute()
|
|
||||||
|
|
||||||
task_id = response.task_id
|
task_id = response.task_id
|
||||||
if not task_id:
|
if not task_id:
|
||||||
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
raise Exception(f"MiniMax generation failed: {response.base_resp}")
|
||||||
|
|
||||||
average_duration = 120 if resolution == "768P" else 240
|
average_duration = 120 if resolution == "768P" else 240
|
||||||
video_generate_operation = PollingOperation(
|
task_result = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/query/video_generation",
|
ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}),
|
||||||
method=HttpMethod.GET,
|
response_model=MinimaxTaskResultResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MinimaxTaskResultResponse,
|
|
||||||
query_params={"task_id": task_id},
|
|
||||||
),
|
|
||||||
completed_statuses=["Success"],
|
|
||||||
failed_statuses=["Fail"],
|
|
||||||
status_extractor=lambda x: x.status.value,
|
status_extractor=lambda x: x.status.value,
|
||||||
estimated_duration=average_duration,
|
estimated_duration=average_duration,
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
task_result = await video_generate_operation.execute()
|
|
||||||
|
|
||||||
file_id = task_result.file_id
|
file_id = task_result.file_id
|
||||||
if file_id is None:
|
if file_id is None:
|
||||||
raise Exception("Request was not successful. Missing file ID.")
|
raise Exception("Request was not successful. Missing file ID.")
|
||||||
file_retrieve_operation = SynchronousOperation(
|
file_result = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/minimax/files/retrieve",
|
ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}),
|
||||||
method=HttpMethod.GET,
|
response_model=MinimaxFileRetrieveResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MinimaxFileRetrieveResponse,
|
|
||||||
query_params={"file_id": int(file_id)},
|
|
||||||
),
|
|
||||||
request=EmptyRequest(),
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
file_result = await file_retrieve_operation.execute()
|
|
||||||
|
|
||||||
file_url = file_result.file.download_url
|
file_url = file_result.file.download_url
|
||||||
if file_url is None:
|
if file_url is None:
|
||||||
raise Exception(
|
raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}")
|
||||||
f"No video was found in the response. Full response: {file_result.model_dump()}"
|
|
||||||
)
|
|
||||||
logging.info(f"Generated video URL: {file_url}")
|
|
||||||
if cls.hidden.unique_id:
|
|
||||||
if hasattr(file_result.file, "backup_download_url"):
|
|
||||||
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
|
|
||||||
else:
|
|
||||||
message = f"Result URL: {file_url}"
|
|
||||||
PromptServer.instance.send_progress_text(message, cls.hidden.unique_id)
|
|
||||||
|
|
||||||
video_io = await download_url_to_bytesio(file_url)
|
if file_result.file.backup_download_url:
|
||||||
if video_io is None:
|
try:
|
||||||
error_msg = f"Failed to download video from {file_url}"
|
return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2))
|
||||||
logging.error(error_msg)
|
except Exception: # if we have a second URL to retrieve the result, try again using that one
|
||||||
raise Exception(error_msg)
|
return IO.NodeOutput(
|
||||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3)
|
||||||
|
)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(file_url))
|
||||||
|
|
||||||
|
|
||||||
class MinimaxExtension(ComfyExtension):
|
class MinimaxExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
MinimaxTextToVideoNode,
|
MinimaxTextToVideoNode,
|
||||||
MinimaxImageToVideoNode,
|
MinimaxImageToVideoNode,
|
||||||
|
|||||||
@ -1,33 +1,30 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any, Callable, Optional, TypeVar
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions
|
|
||||||
|
|
||||||
from comfy_api_nodes.apis import (
|
|
||||||
MoonvalleyTextToVideoRequest,
|
|
||||||
MoonvalleyTextToVideoInferenceParams,
|
|
||||||
MoonvalleyVideoToVideoInferenceParams,
|
|
||||||
MoonvalleyVideoToVideoRequest,
|
|
||||||
MoonvalleyPromptResponse,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
download_url_to_video_output,
|
|
||||||
upload_images_to_comfyapi,
|
|
||||||
upload_video_to_comfyapi,
|
|
||||||
)
|
|
||||||
|
|
||||||
from comfy_api.input import VideoInput
|
from comfy_api.input import VideoInput
|
||||||
from comfy_api.latest import ComfyExtension, InputImpl, io as comfy_io
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
import av
|
from comfy_api_nodes.apis import (
|
||||||
import io
|
MoonvalleyPromptResponse,
|
||||||
|
MoonvalleyTextToVideoInferenceParams,
|
||||||
|
MoonvalleyTextToVideoRequest,
|
||||||
|
MoonvalleyVideoToVideoInferenceParams,
|
||||||
|
MoonvalleyVideoToVideoRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
ApiEndpoint,
|
||||||
|
download_url_to_video_output,
|
||||||
|
poll_op,
|
||||||
|
sync_op,
|
||||||
|
trim_video,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
|
validate_container_format_is_mp4,
|
||||||
|
validate_image_dimensions,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
|
||||||
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads"
|
||||||
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts"
|
||||||
@ -50,13 +47,6 @@ MAX_VID_HEIGHT = 10000
|
|||||||
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing
|
||||||
|
|
||||||
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
class MoonvalleyApiError(Exception):
|
|
||||||
"""Base exception for Moonvalley API errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool:
|
||||||
@ -68,64 +58,7 @@ def validate_task_creation_response(response) -> None:
|
|||||||
if not is_valid_task_creation_response(response):
|
if not is_valid_task_creation_response(response):
|
||||||
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise MoonvalleyApiError(error_msg)
|
raise RuntimeError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
def get_video_from_response(response):
|
|
||||||
video = response.output_url
|
|
||||||
logging.info(
|
|
||||||
"Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video
|
|
||||||
)
|
|
||||||
return video
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_url_from_response(response) -> Optional[str]:
|
|
||||||
"""Returns the first video url from the Moonvalley video generation task result.
|
|
||||||
Will not raise an error if the response is not valid.
|
|
||||||
"""
|
|
||||||
if response:
|
|
||||||
return str(get_video_from_response(response))
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def poll_until_finished(
|
|
||||||
auth_kwargs: dict[str, str],
|
|
||||||
api_endpoint: ApiEndpoint[Any, R],
|
|
||||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> R:
|
|
||||||
"""Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response."""
|
|
||||||
return await PollingOperation(
|
|
||||||
poll_endpoint=api_endpoint,
|
|
||||||
completed_statuses=[
|
|
||||||
"completed",
|
|
||||||
],
|
|
||||||
max_poll_attempts=240, # 64 minutes with 16s interval
|
|
||||||
poll_interval=16.0,
|
|
||||||
failed_statuses=["error"],
|
|
||||||
status_extractor=lambda response: (
|
|
||||||
response.status if response and response.status else None
|
|
||||||
),
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
result_url_extractor=result_url_extractor,
|
|
||||||
node_id=node_id,
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
|
|
||||||
def validate_prompts(
|
|
||||||
prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH
|
|
||||||
):
|
|
||||||
"""Verifies that the prompt isn't empty and that neither prompt is too long."""
|
|
||||||
if not prompt:
|
|
||||||
raise ValueError("Positive prompt is empty")
|
|
||||||
if len(prompt) > max_length:
|
|
||||||
raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
|
|
||||||
if negative_prompt and len(negative_prompt) > max_length:
|
|
||||||
raise ValueError(
|
|
||||||
f"Negative prompt is too long: {len(negative_prompt)} characters"
|
|
||||||
)
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||||
@ -144,7 +77,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
|||||||
"""
|
"""
|
||||||
width, height = _get_video_dimensions(video)
|
width, height = _get_video_dimensions(video)
|
||||||
_validate_video_dimensions(width, height)
|
_validate_video_dimensions(width, height)
|
||||||
_validate_container_format(video)
|
validate_container_format_is_mp4(video)
|
||||||
|
|
||||||
return _validate_and_trim_duration(video)
|
return _validate_and_trim_duration(video)
|
||||||
|
|
||||||
@ -169,21 +102,8 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (width, height) not in supported_resolutions:
|
if (width, height) not in supported_resolutions:
|
||||||
supported_list = ", ".join(
|
supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)])
|
||||||
[f"{w}x{h}" for w, h in sorted(supported_resolutions)]
|
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||||
)
|
|
||||||
raise ValueError(
|
|
||||||
f"Resolution {width}x{height} not supported. Supported: {supported_list}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_container_format(video: VideoInput) -> None:
|
|
||||||
"""Validates video container format is MP4."""
|
|
||||||
container_format = video.get_container_format()
|
|
||||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
|
||||||
raise ValueError(
|
|
||||||
f"Only MP4 container format supported. Got: {container_format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||||
@ -196,7 +116,7 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
|||||||
def _validate_minimum_duration(duration: float) -> None:
|
def _validate_minimum_duration(duration: float) -> None:
|
||||||
"""Ensures video is at least 5 seconds long."""
|
"""Ensures video is at least 5 seconds long."""
|
||||||
if duration < 5:
|
if duration < 5:
|
||||||
raise MoonvalleyApiError("Input video must be at least 5 seconds long.")
|
raise ValueError("Input video must be at least 5 seconds long.")
|
||||||
|
|
||||||
|
|
||||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||||
@ -206,127 +126,6 @@ def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
|||||||
return video
|
return video
|
||||||
|
|
||||||
|
|
||||||
def trim_video(video: VideoInput, duration_sec: float) -> VideoInput:
|
|
||||||
"""
|
|
||||||
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
|
||||||
using av to avoid loading entire video into memory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
video: Input video to trim
|
|
||||||
duration_sec: Duration in seconds to keep from the beginning
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
VideoFromFile object that owns the output buffer
|
|
||||||
"""
|
|
||||||
output_buffer = io.BytesIO()
|
|
||||||
|
|
||||||
input_container = None
|
|
||||||
output_container = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Get the stream source - this avoids loading entire video into memory
|
|
||||||
# when the source is already a file path
|
|
||||||
input_source = video.get_stream_source()
|
|
||||||
|
|
||||||
# Open containers
|
|
||||||
input_container = av.open(input_source, mode="r")
|
|
||||||
output_container = av.open(output_buffer, mode="w", format="mp4")
|
|
||||||
|
|
||||||
# Set up output streams for re-encoding
|
|
||||||
video_stream = None
|
|
||||||
audio_stream = None
|
|
||||||
|
|
||||||
for stream in input_container.streams:
|
|
||||||
logging.info(f"Found stream: type={stream.type}, class={type(stream)}")
|
|
||||||
if isinstance(stream, av.VideoStream):
|
|
||||||
# Create output video stream with same parameters
|
|
||||||
video_stream = output_container.add_stream(
|
|
||||||
"h264", rate=stream.average_rate
|
|
||||||
)
|
|
||||||
video_stream.width = stream.width
|
|
||||||
video_stream.height = stream.height
|
|
||||||
video_stream.pix_fmt = "yuv420p"
|
|
||||||
logging.info(
|
|
||||||
f"Added video stream: {stream.width}x{stream.height} @ {stream.average_rate}fps"
|
|
||||||
)
|
|
||||||
elif isinstance(stream, av.AudioStream):
|
|
||||||
# Create output audio stream with same parameters
|
|
||||||
audio_stream = output_container.add_stream(
|
|
||||||
"aac", rate=stream.sample_rate
|
|
||||||
)
|
|
||||||
audio_stream.sample_rate = stream.sample_rate
|
|
||||||
audio_stream.layout = stream.layout
|
|
||||||
logging.info(
|
|
||||||
f"Added audio stream: {stream.sample_rate}Hz, {stream.channels} channels"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate target frame count that's divisible by 16
|
|
||||||
fps = input_container.streams.video[0].average_rate
|
|
||||||
estimated_frames = int(duration_sec * fps)
|
|
||||||
target_frames = (
|
|
||||||
estimated_frames // 16
|
|
||||||
) * 16 # Round down to nearest multiple of 16
|
|
||||||
|
|
||||||
if target_frames == 0:
|
|
||||||
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
|
||||||
|
|
||||||
frame_count = 0
|
|
||||||
audio_frame_count = 0
|
|
||||||
|
|
||||||
# Decode and re-encode video frames
|
|
||||||
if video_stream:
|
|
||||||
for frame in input_container.decode(video=0):
|
|
||||||
if frame_count >= target_frames:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Re-encode frame
|
|
||||||
for packet in video_stream.encode(frame):
|
|
||||||
output_container.mux(packet)
|
|
||||||
frame_count += 1
|
|
||||||
|
|
||||||
# Flush encoder
|
|
||||||
for packet in video_stream.encode():
|
|
||||||
output_container.mux(packet)
|
|
||||||
|
|
||||||
logging.info(
|
|
||||||
f"Encoded {frame_count} video frames (target: {target_frames})"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Decode and re-encode audio frames
|
|
||||||
if audio_stream:
|
|
||||||
input_container.seek(0) # Reset to beginning for audio
|
|
||||||
for frame in input_container.decode(audio=0):
|
|
||||||
if frame.time >= duration_sec:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Re-encode frame
|
|
||||||
for packet in audio_stream.encode(frame):
|
|
||||||
output_container.mux(packet)
|
|
||||||
audio_frame_count += 1
|
|
||||||
|
|
||||||
# Flush encoder
|
|
||||||
for packet in audio_stream.encode():
|
|
||||||
output_container.mux(packet)
|
|
||||||
|
|
||||||
logging.info(f"Encoded {audio_frame_count} audio frames")
|
|
||||||
|
|
||||||
# Close containers
|
|
||||||
output_container.close()
|
|
||||||
input_container.close()
|
|
||||||
|
|
||||||
# Return as VideoFromFile using the buffer
|
|
||||||
output_buffer.seek(0)
|
|
||||||
return InputImpl.VideoFromFile(output_buffer)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Clean up on error
|
|
||||||
if input_container is not None:
|
|
||||||
input_container.close()
|
|
||||||
if output_container is not None:
|
|
||||||
output_container.close()
|
|
||||||
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
|
||||||
|
|
||||||
|
|
||||||
def parse_width_height_from_res(resolution: str):
|
def parse_width_height_from_res(resolution: str):
|
||||||
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
# Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict
|
||||||
res_map = {
|
res_map = {
|
||||||
@ -335,7 +134,7 @@ def parse_width_height_from_res(resolution: str):
|
|||||||
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
|
"1:1 (1152 x 1152)": {"width": 1152, "height": 1152},
|
||||||
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
|
"4:3 (1536 x 1152)": {"width": 1536, "height": 1152},
|
||||||
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
"3:4 (1152 x 1536)": {"width": 1152, "height": 1536},
|
||||||
"21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
# "21:9 (2560 x 1080)": {"width": 2560, "height": 1080},
|
||||||
}
|
}
|
||||||
return res_map.get(resolution, {"width": 1920, "height": 1080})
|
return res_map.get(resolution, {"width": 1920, "height": 1080})
|
||||||
|
|
||||||
@ -350,52 +149,47 @@ def parse_control_parameter(value):
|
|||||||
return control_map.get(value, control_map["Motion Transfer"])
|
return control_map.get(value, control_map["Motion Transfer"])
|
||||||
|
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse:
|
||||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
|
return await poll_op(
|
||||||
) -> MoonvalleyPromptResponse:
|
cls,
|
||||||
return await poll_until_finished(
|
ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"),
|
||||||
auth_kwargs,
|
response_model=MoonvalleyPromptResponse,
|
||||||
ApiEndpoint(
|
status_extractor=lambda r: (r.status if r and r.status else None),
|
||||||
path=f"{API_PROMPTS_ENDPOINT}/{task_id}",
|
poll_interval=16.0,
|
||||||
method=HttpMethod.GET,
|
max_poll_attempts=240,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=MoonvalleyPromptResponse,
|
|
||||||
),
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
node_id=node_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MoonvalleyImg2VideoNode",
|
node_id="MoonvalleyImg2VideoNode",
|
||||||
display_name="Moonvalley Marey Image to Video",
|
display_name="Moonvalley Marey Image to Video",
|
||||||
category="api node/video/Moonvalley Marey",
|
category="api node/video/Moonvalley Marey",
|
||||||
description="Moonvalley Marey Image to Video Node",
|
description="Moonvalley Marey Image to Video Node",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="The reference image used to generate the video",
|
tooltip="The reference image used to generate the video",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
tooltip="Negative prompt text",
|
tooltip="Negative prompt text",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[
|
options=[
|
||||||
"16:9 (1920 x 1080)",
|
"16:9 (1920 x 1080)",
|
||||||
@ -403,42 +197,43 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
|||||||
"1:1 (1152 x 1152)",
|
"1:1 (1152 x 1152)",
|
||||||
"4:3 (1536 x 1152)",
|
"4:3 (1536 x 1152)",
|
||||||
"3:4 (1152 x 1536)",
|
"3:4 (1152 x 1536)",
|
||||||
"21:9 (2560 x 1080)",
|
# "21:9 (2560 x 1080)",
|
||||||
],
|
],
|
||||||
default="16:9 (1920 x 1080)",
|
default="16:9 (1920 x 1080)",
|
||||||
tooltip="Resolution of the output video",
|
tooltip="Resolution of the output video",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"prompt_adherence",
|
"prompt_adherence",
|
||||||
default=10.0,
|
default=4.5,
|
||||||
min=1.0,
|
min=1.0,
|
||||||
max=20.0,
|
max=20.0,
|
||||||
step=1.0,
|
step=1.0,
|
||||||
tooltip="Guidance scale for generation control",
|
tooltip="Guidance scale for generation control",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=9,
|
default=9,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Random seed value",
|
tooltip="Random seed value",
|
||||||
|
control_after_generate=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
default=100,
|
default=33,
|
||||||
min=1,
|
min=1,
|
||||||
max=100,
|
max=100,
|
||||||
step=1,
|
step=1,
|
||||||
tooltip="Number of denoising steps",
|
tooltip="Number of denoising steps",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -453,22 +248,17 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
|||||||
prompt_adherence: float,
|
prompt_adherence: float,
|
||||||
seed: int,
|
seed: int,
|
||||||
steps: int,
|
steps: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH)
|
||||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
|
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
width_height = parse_width_height_from_res(resolution)
|
width_height = parse_width_height_from_res(resolution)
|
||||||
|
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
guidance_scale=prompt_adherence,
|
guidance_scale=prompt_adherence,
|
||||||
num_frames=128,
|
|
||||||
width=width_height["width"],
|
width=width_height["width"],
|
||||||
height=width_height["height"],
|
height=width_height["height"],
|
||||||
use_negative_prompts=True,
|
use_negative_prompts=True,
|
||||||
@ -476,85 +266,69 @@ class MoonvalleyImg2VideoNode(comfy_io.ComfyNode):
|
|||||||
|
|
||||||
# Get MIME type from tensor - assuming PNG format for image tensors
|
# Get MIME type from tensor - assuming PNG format for image tensors
|
||||||
mime_type = "image/png"
|
mime_type = "image/png"
|
||||||
|
image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0]
|
||||||
image_url = (
|
task_creation_response = await sync_op(
|
||||||
await upload_images_to_comfyapi(
|
cls,
|
||||||
image, max_images=1, auth_kwargs=auth, mime_type=mime_type
|
endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"),
|
||||||
)
|
response_model=MoonvalleyPromptResponse,
|
||||||
)[0]
|
data=MoonvalleyTextToVideoRequest(
|
||||||
|
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
||||||
request = MoonvalleyTextToVideoRequest(
|
|
||||||
image_url=image_url, prompt_text=prompt, inference_params=inference_params
|
|
||||||
)
|
|
||||||
initial_operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=API_IMG2VIDEO_ENDPOINT,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=MoonvalleyTextToVideoRequest,
|
|
||||||
response_model=MoonvalleyPromptResponse,
|
|
||||||
),
|
),
|
||||||
request=request,
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
task_creation_response = await initial_operation.execute()
|
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
final_response = await get_response(cls, task_creation_response.id)
|
||||||
|
|
||||||
final_response = await get_response(
|
|
||||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
|
||||||
)
|
|
||||||
video = await download_url_to_video_output(final_response.output_url)
|
video = await download_url_to_video_output(final_response.output_url)
|
||||||
return comfy_io.NodeOutput(video)
|
return IO.NodeOutput(video)
|
||||||
|
|
||||||
|
|
||||||
class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MoonvalleyVideo2VideoNode",
|
node_id="MoonvalleyVideo2VideoNode",
|
||||||
display_name="Moonvalley Marey Video to Video",
|
display_name="Moonvalley Marey Video to Video",
|
||||||
category="api node/video/Moonvalley Marey",
|
category="api node/video/Moonvalley Marey",
|
||||||
description="",
|
description="",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="Describes the video to generate",
|
tooltip="Describes the video to generate",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
tooltip="Negative prompt text",
|
tooltip="Negative prompt text",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=9,
|
default=9,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Random seed value",
|
tooltip="Random seed value",
|
||||||
control_after_generate=False,
|
control_after_generate=False,
|
||||||
),
|
),
|
||||||
comfy_io.Video.Input(
|
IO.Video.Input(
|
||||||
"video",
|
"video",
|
||||||
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
|
tooltip="The reference video used to generate the output video. Must be at least 5 seconds long. "
|
||||||
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
"Videos longer than 5s will be automatically trimmed. Only MP4 format supported.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"control_type",
|
"control_type",
|
||||||
options=["Motion Transfer", "Pose Transfer"],
|
options=["Motion Transfer", "Pose Transfer"],
|
||||||
default="Motion Transfer",
|
default="Motion Transfer",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"motion_intensity",
|
"motion_intensity",
|
||||||
default=100,
|
default=100,
|
||||||
min=0,
|
min=0,
|
||||||
@ -563,12 +337,21 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
|||||||
tooltip="Only used if control_type is 'Motion Transfer'",
|
tooltip="Only used if control_type is 'Motion Transfer'",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"steps",
|
||||||
|
default=33,
|
||||||
|
min=1,
|
||||||
|
max=100,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
tooltip="Number of inference steps",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -582,16 +365,13 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
|||||||
video: Optional[VideoInput] = None,
|
video: Optional[VideoInput] = None,
|
||||||
control_type: str = "Motion Transfer",
|
control_type: str = "Motion Transfer",
|
||||||
motion_intensity: Optional[int] = 100,
|
motion_intensity: Optional[int] = 100,
|
||||||
) -> comfy_io.NodeOutput:
|
steps=33,
|
||||||
auth = {
|
prompt_adherence=4.5,
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
) -> IO.NodeOutput:
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
validated_video = validate_video_to_video_input(video)
|
validated_video = validate_video_to_video_input(video)
|
||||||
video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth)
|
video_url = await upload_video_to_comfyapi(cls, validated_video)
|
||||||
|
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
validate_prompts(prompt, negative_prompt)
|
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
|
|
||||||
# Only include motion_intensity for Motion Transfer
|
# Only include motion_intensity for Motion Transfer
|
||||||
control_params = {}
|
control_params = {}
|
||||||
@ -602,65 +382,52 @@ class MoonvalleyVideo2VideoNode(comfy_io.ComfyNode):
|
|||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
control_params=control_params,
|
control_params=control_params,
|
||||||
|
steps=steps,
|
||||||
|
guidance_scale=prompt_adherence,
|
||||||
)
|
)
|
||||||
|
|
||||||
control = parse_control_parameter(control_type)
|
task_creation_response = await sync_op(
|
||||||
|
cls,
|
||||||
request = MoonvalleyVideoToVideoRequest(
|
endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"),
|
||||||
control_type=control,
|
response_model=MoonvalleyPromptResponse,
|
||||||
video_url=video_url,
|
data=MoonvalleyVideoToVideoRequest(
|
||||||
prompt_text=prompt,
|
control_type=parse_control_parameter(control_type),
|
||||||
inference_params=inference_params,
|
video_url=video_url,
|
||||||
)
|
prompt_text=prompt,
|
||||||
|
inference_params=inference_params,
|
||||||
initial_operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=API_VIDEO2VIDEO_ENDPOINT,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=MoonvalleyVideoToVideoRequest,
|
|
||||||
response_model=MoonvalleyPromptResponse,
|
|
||||||
),
|
),
|
||||||
request=request,
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
task_creation_response = await initial_operation.execute()
|
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
final_response = await get_response(cls, task_creation_response.id)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||||
final_response = await get_response(
|
|
||||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
|
||||||
)
|
|
||||||
|
|
||||||
video = await download_url_to_video_output(final_response.output_url)
|
|
||||||
return comfy_io.NodeOutput(video)
|
|
||||||
|
|
||||||
|
|
||||||
class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
class MoonvalleyTxt2VideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="MoonvalleyTxt2VideoNode",
|
node_id="MoonvalleyTxt2VideoNode",
|
||||||
display_name="Moonvalley Marey Text to Video",
|
display_name="Moonvalley Marey Text to Video",
|
||||||
category="api node/video/Moonvalley Marey",
|
category="api node/video/Moonvalley Marey",
|
||||||
description="",
|
description="",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
default="<synthetic> <scene cut> gopro, bright, contrast, static, overexposed, vignette, "
|
||||||
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
"artifacts, still, noise, texture, scanlines, videogame, 360 camera, VR, transition, "
|
||||||
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
"flare, saturation, distorted, warped, wide angle, saturated, vibrant, glowing, "
|
||||||
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
"cross dissolve, cheesy, ugly hands, mutated hands, mutant, disfigured, extra fingers, "
|
||||||
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
"blown out, horrible, blurry, worst quality, bad, dissolve, melt, fade in, fade out, "
|
||||||
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
"wobbly, weird, low quality, plastic, stock footage, video camera, boring",
|
||||||
tooltip="Negative prompt text",
|
tooltip="Negative prompt text",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[
|
options=[
|
||||||
"16:9 (1920 x 1080)",
|
"16:9 (1920 x 1080)",
|
||||||
@ -673,37 +440,38 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
|||||||
default="16:9 (1920 x 1080)",
|
default="16:9 (1920 x 1080)",
|
||||||
tooltip="Resolution of the output video",
|
tooltip="Resolution of the output video",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"prompt_adherence",
|
"prompt_adherence",
|
||||||
default=10.0,
|
default=4.0,
|
||||||
min=1.0,
|
min=1.0,
|
||||||
max=20.0,
|
max=20.0,
|
||||||
step=1.0,
|
step=1.0,
|
||||||
tooltip="Guidance scale for generation control",
|
tooltip="Guidance scale for generation control",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=9,
|
default=9,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
tooltip="Random seed value",
|
tooltip="Random seed value",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
default=100,
|
default=33,
|
||||||
min=1,
|
min=1,
|
||||||
max=100,
|
max=100,
|
||||||
step=1,
|
step=1,
|
||||||
tooltip="Inference steps",
|
tooltip="Inference steps",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -717,15 +485,11 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
|||||||
prompt_adherence: float,
|
prompt_adherence: float,
|
||||||
seed: int,
|
seed: int,
|
||||||
steps: int,
|
steps: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
|
validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH)
|
||||||
width_height = parse_width_height_from_res(resolution)
|
width_height = parse_width_height_from_res(resolution)
|
||||||
|
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
inference_params = MoonvalleyTextToVideoInferenceParams(
|
inference_params = MoonvalleyTextToVideoInferenceParams(
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
steps=steps,
|
steps=steps,
|
||||||
@ -735,35 +499,21 @@ class MoonvalleyTxt2VideoNode(comfy_io.ComfyNode):
|
|||||||
width=width_height["width"],
|
width=width_height["width"],
|
||||||
height=width_height["height"],
|
height=width_height["height"],
|
||||||
)
|
)
|
||||||
request = MoonvalleyTextToVideoRequest(
|
|
||||||
prompt_text=prompt, inference_params=inference_params
|
|
||||||
)
|
|
||||||
|
|
||||||
init_op = SynchronousOperation(
|
task_creation_response = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path=API_TXT2VIDEO_ENDPOINT,
|
endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=MoonvalleyPromptResponse,
|
||||||
request_model=MoonvalleyTextToVideoRequest,
|
data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params),
|
||||||
response_model=MoonvalleyPromptResponse,
|
|
||||||
),
|
|
||||||
request=request,
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
task_creation_response = await init_op.execute()
|
|
||||||
validate_task_creation_response(task_creation_response)
|
validate_task_creation_response(task_creation_response)
|
||||||
task_id = task_creation_response.id
|
final_response = await get_response(cls, task_creation_response.id)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(final_response.output_url))
|
||||||
final_response = await get_response(
|
|
||||||
task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id
|
|
||||||
)
|
|
||||||
|
|
||||||
video = await download_url_to_video_output(final_response.output_url)
|
|
||||||
return comfy_io.NodeOutput(video)
|
|
||||||
|
|
||||||
|
|
||||||
class MoonvalleyExtension(ComfyExtension):
|
class MoonvalleyExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
MoonvalleyImg2VideoNode,
|
MoonvalleyImg2VideoNode,
|
||||||
MoonvalleyTxt2VideoNode,
|
MoonvalleyTxt2VideoNode,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -7,40 +7,23 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, TypeVar
|
from typing import Optional
|
||||||
from enum import Enum
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api_nodes.apis import pika_api as pika_defs
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
validate_string,
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis import (
|
|
||||||
PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
|
||||||
PikaBodyGenerate22I2vGenerate22I2vPost,
|
|
||||||
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
|
||||||
PikaBodyGenerate22T2vGenerate22T2vPost,
|
|
||||||
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
|
||||||
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
|
|
||||||
PikaGenerateResponse,
|
|
||||||
PikaVideoResponse,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
EmptyRequest,
|
sync_op,
|
||||||
HttpMethod,
|
poll_op,
|
||||||
PollingOperation,
|
|
||||||
SynchronousOperation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
|
||||||
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
|
||||||
@ -55,152 +38,58 @@ PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
|
|||||||
PATH_VIDEO_GET = "/proxy/pika/videos"
|
PATH_VIDEO_GET = "/proxy/pika/videos"
|
||||||
|
|
||||||
|
|
||||||
class PikaDurationEnum(int, Enum):
|
|
||||||
integer_5 = 5
|
|
||||||
integer_10 = 10
|
|
||||||
|
|
||||||
|
|
||||||
class PikaResolutionEnum(str, Enum):
|
|
||||||
field_1080p = "1080p"
|
|
||||||
field_720p = "720p"
|
|
||||||
|
|
||||||
|
|
||||||
class Pikaffect(str, Enum):
|
|
||||||
Cake_ify = "Cake-ify"
|
|
||||||
Crumble = "Crumble"
|
|
||||||
Crush = "Crush"
|
|
||||||
Decapitate = "Decapitate"
|
|
||||||
Deflate = "Deflate"
|
|
||||||
Dissolve = "Dissolve"
|
|
||||||
Explode = "Explode"
|
|
||||||
Eye_pop = "Eye-pop"
|
|
||||||
Inflate = "Inflate"
|
|
||||||
Levitate = "Levitate"
|
|
||||||
Melt = "Melt"
|
|
||||||
Peel = "Peel"
|
|
||||||
Poke = "Poke"
|
|
||||||
Squish = "Squish"
|
|
||||||
Ta_da = "Ta-da"
|
|
||||||
Tear = "Tear"
|
|
||||||
|
|
||||||
|
|
||||||
class PikaApiError(Exception):
|
|
||||||
"""Exception for Pika API errors."""
|
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_video_response(response: PikaVideoResponse) -> bool:
|
|
||||||
"""Check if the video response is valid."""
|
|
||||||
return hasattr(response, "url") and response.url is not None
|
|
||||||
|
|
||||||
|
|
||||||
def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
|
|
||||||
"""Check if the initial response is valid."""
|
|
||||||
return hasattr(response, "video_id") and response.video_id is not None
|
|
||||||
|
|
||||||
|
|
||||||
async def poll_for_task_status(
|
|
||||||
task_id: str,
|
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> PikaGenerateResponse:
|
|
||||||
polling_operation = PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path=f"{PATH_VIDEO_GET}/{task_id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PikaVideoResponse,
|
|
||||||
),
|
|
||||||
completed_statuses=[
|
|
||||||
"finished",
|
|
||||||
],
|
|
||||||
failed_statuses=["failed", "cancelled"],
|
|
||||||
status_extractor=lambda response: (
|
|
||||||
response.status.value if response.status else None
|
|
||||||
),
|
|
||||||
progress_extractor=lambda response: (
|
|
||||||
response.progress if hasattr(response, "progress") else None
|
|
||||||
),
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
result_url_extractor=lambda response: (
|
|
||||||
response.url if hasattr(response, "url") else None
|
|
||||||
),
|
|
||||||
node_id=node_id,
|
|
||||||
estimated_duration=60
|
|
||||||
)
|
|
||||||
return await polling_operation.execute()
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_task(
|
async def execute_task(
|
||||||
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
|
task_id: str,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
cls: type[IO.ComfyNode],
|
||||||
node_id: Optional[str] = None,
|
) -> IO.NodeOutput:
|
||||||
) -> tuple[VideoFromFile]:
|
final_response: pika_defs.PikaVideoResponse = await poll_op(
|
||||||
"""Executes the initial operation then polls for the task status until it is completed.
|
cls,
|
||||||
|
ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
|
||||||
Args:
|
response_model=pika_defs.PikaVideoResponse,
|
||||||
initial_operation: The initial operation to execute.
|
status_extractor=lambda response: (response.status.value if response.status else None),
|
||||||
auth_kwargs: The authentication token(s) to use for the API call.
|
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
|
||||||
|
estimated_duration=60,
|
||||||
Returns:
|
max_poll_attempts=240,
|
||||||
A tuple containing the video file as a VIDEO output.
|
)
|
||||||
"""
|
if not final_response.url:
|
||||||
initial_response = await initial_operation.execute()
|
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
|
||||||
if not is_valid_initial_response(initial_response):
|
|
||||||
error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
|
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise PikaApiError(error_msg)
|
raise Exception(error_msg)
|
||||||
|
video_url = final_response.url
|
||||||
task_id = initial_response.video_id
|
|
||||||
final_response = await poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
|
|
||||||
if not is_valid_video_response(final_response):
|
|
||||||
error_msg = (
|
|
||||||
f"Pika task {task_id} succeeded but no video data found in response."
|
|
||||||
)
|
|
||||||
logging.error(error_msg)
|
|
||||||
raise PikaApiError(error_msg)
|
|
||||||
|
|
||||||
video_url = str(final_response.url)
|
|
||||||
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(video_url))
|
||||||
return (await download_url_to_video_output(video_url),)
|
|
||||||
|
|
||||||
|
|
||||||
def get_base_inputs_types() -> list[comfy_io.Input]:
|
def get_base_inputs_types() -> list[IO.Input]:
|
||||||
"""Get the base required inputs types common to all Pika nodes."""
|
"""Get the base required inputs types common to all Pika nodes."""
|
||||||
return [
|
return [
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
IO.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
IO.String.Input("negative_prompt", multiline=True),
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input("resolution", options=["1080p", "720p"], default="1080p"),
|
||||||
"resolution", options=[resolution.value for resolution in PikaResolutionEnum], default="1080p"
|
IO.Combo.Input("duration", options=[5, 10], default=5),
|
||||||
),
|
|
||||||
comfy_io.Combo.Input(
|
|
||||||
"duration", options=[duration.value for duration in PikaDurationEnum], default=5
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
class PikaImageToVideo(IO.ComfyNode):
|
||||||
"""Pika 2.2 Image to Video Node."""
|
"""Pika 2.2 Image to Video Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PikaImageToVideoNode2_2",
|
node_id="PikaImageToVideoNode2_2",
|
||||||
display_name="Pika Image to Video",
|
display_name="Pika Image to Video",
|
||||||
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
description="Sends an image and prompt to the Pika API v2.2 to generate a video.",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image", tooltip="The image to convert to video"),
|
IO.Image.Input("image", tooltip="The image to convert to video"),
|
||||||
*get_base_inputs_types(),
|
*get_base_inputs_types(),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -214,53 +103,40 @@ class PikaImageToVideoV2_2(comfy_io.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost(
|
||||||
# Prepare non-file data
|
|
||||||
pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_IMAGE_TO_VIDEO,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
class PikaTextToVideoNode(IO.ComfyNode):
|
||||||
"""Pika Text2Video v2.2 Node."""
|
"""Pika Text2Video v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PikaTextToVideoNode2_2",
|
node_id="PikaTextToVideoNode2_2",
|
||||||
display_name="Pika Text to Video",
|
display_name="Pika Text to Video",
|
||||||
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
description="Sends a text prompt to the Pika API v2.2 to generate a video.",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
*get_base_inputs_types(),
|
*get_base_inputs_types(),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
step=0.001,
|
step=0.001,
|
||||||
min=0.4,
|
min=0.4,
|
||||||
@ -269,11 +145,11 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
tooltip="Aspect ratio (width / height)",
|
tooltip="Aspect ratio (width / height)",
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -287,19 +163,12 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
resolution: str,
|
resolution: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
aspect_ratio: float,
|
aspect_ratio: float,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_TEXT_TO_VIDEO,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=PikaBodyGenerate22T2vGenerate22T2vPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -307,30 +176,29 @@ class PikaTextToVideoNodeV2_2(comfy_io.ComfyNode):
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
aspectRatio=aspect_ratio,
|
aspectRatio=aspect_ratio,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
content_type="application/x-www-form-urlencoded",
|
content_type="application/x-www-form-urlencoded",
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaScenesV2_2(comfy_io.ComfyNode):
|
class PikaScenes(IO.ComfyNode):
|
||||||
"""PikaScenes v2.2 Node."""
|
"""PikaScenes v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PikaScenesV2_2",
|
node_id="PikaScenesV2_2",
|
||||||
display_name="Pika Scenes (Video Image Composition)",
|
display_name="Pika Scenes (Video Image Composition)",
|
||||||
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
description="Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them.",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
*get_base_inputs_types(),
|
*get_base_inputs_types(),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ingredients_mode",
|
"ingredients_mode",
|
||||||
options=["creative", "precise"],
|
options=["creative", "precise"],
|
||||||
default="creative",
|
default="creative",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
step=0.001,
|
step=0.001,
|
||||||
min=0.4,
|
min=0.4,
|
||||||
@ -338,37 +206,37 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
default=1.7777777777777777,
|
default=1.7777777777777777,
|
||||||
tooltip="Aspect ratio (width / height)",
|
tooltip="Aspect ratio (width / height)",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image_ingredient_1",
|
"image_ingredient_1",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image_ingredient_2",
|
"image_ingredient_2",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image_ingredient_3",
|
"image_ingredient_3",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image_ingredient_4",
|
"image_ingredient_4",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image_ingredient_5",
|
"image_ingredient_5",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Image that will be used as ingredient to create a video.",
|
tooltip="Image that will be used as ingredient to create a video.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -388,8 +256,7 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
image_ingredient_3: Optional[torch.Tensor] = None,
|
image_ingredient_3: Optional[torch.Tensor] = None,
|
||||||
image_ingredient_4: Optional[torch.Tensor] = None,
|
image_ingredient_4: Optional[torch.Tensor] = None,
|
||||||
image_ingredient_5: Optional[torch.Tensor] = None,
|
image_ingredient_5: Optional[torch.Tensor] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
# Convert all passed images to BytesIO
|
|
||||||
all_image_bytes_io = []
|
all_image_bytes_io = []
|
||||||
for image in [
|
for image in [
|
||||||
image_ingredient_1,
|
image_ingredient_1,
|
||||||
@ -399,16 +266,14 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
image_ingredient_5,
|
image_ingredient_5,
|
||||||
]:
|
]:
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
all_image_bytes_io.append(tensor_to_bytesio(image))
|
||||||
image_bytes_io.seek(0)
|
|
||||||
all_image_bytes_io.append(image_bytes_io)
|
|
||||||
|
|
||||||
pika_files = [
|
pika_files = [
|
||||||
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
("images", (f"image_{i}.png", image_bytes_io, "image/png"))
|
||||||
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
for i, image_bytes_io in enumerate(all_image_bytes_io)
|
||||||
]
|
]
|
||||||
|
|
||||||
pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
pika_request_data = pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost(
|
||||||
ingredientsMode=ingredients_mode,
|
ingredientsMode=ingredients_mode,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -417,53 +282,45 @@ class PikaScenesV2_2(comfy_io.ComfyNode):
|
|||||||
duration=duration,
|
duration=duration,
|
||||||
aspectRatio=aspect_ratio,
|
aspectRatio=aspect_ratio,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKASCENES,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikAdditionsNode(comfy_io.ComfyNode):
|
class PikAdditionsNode(IO.ComfyNode):
|
||||||
"""Pika Pikadditions Node. Add an image into a video."""
|
"""Pika Pikadditions Node. Add an image into a video."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Pikadditions",
|
node_id="Pikadditions",
|
||||||
display_name="Pikadditions (Video Object Insertion)",
|
display_name="Pikadditions (Video Object Insertion)",
|
||||||
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
description="Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result.",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Video.Input("video", tooltip="The video to add an image to."),
|
IO.Video.Input("video", tooltip="The video to add an image to."),
|
||||||
comfy_io.Image.Input("image", tooltip="The image to add to the video."),
|
IO.Image.Input("image", tooltip="The image to add to the video."),
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
IO.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
IO.String.Input("negative_prompt", multiline=True),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
min=0,
|
min=0,
|
||||||
max=0xFFFFFFFF,
|
max=0xFFFFFFFF,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -476,70 +333,70 @@ class PikAdditionsNode(comfy_io.ComfyNode):
|
|||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
# Convert video to BytesIO
|
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
image_bytes_io = tensor_to_bytesio(image)
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {
|
pika_files = {
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
"image": ("image.png", image_bytes_io, "image/png"),
|
||||||
}
|
}
|
||||||
|
pika_request_data = pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
||||||
# Prepare non-file data
|
|
||||||
pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKADDITIONS,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaSwapsNode(comfy_io.ComfyNode):
|
class PikaSwapsNode(IO.ComfyNode):
|
||||||
"""Pika Pikaswaps Node."""
|
"""Pika Pikaswaps Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Pikaswaps",
|
node_id="Pikaswaps",
|
||||||
display_name="Pika Swaps (Video Object Replacement)",
|
display_name="Pika Swaps (Video Object Replacement)",
|
||||||
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
description="Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates.",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Video.Input("video", tooltip="The video to swap an object in."),
|
IO.Video.Input("video", tooltip="The video to swap an object in."),
|
||||||
comfy_io.Image.Input("image", tooltip="The image used to replace the masked object in the video."),
|
IO.Image.Input(
|
||||||
comfy_io.Mask.Input("mask", tooltip="Use the mask to define areas in the video to replace"),
|
"image",
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
tooltip="The image used to replace the masked object in the video.",
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
optional=True,
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
),
|
||||||
|
IO.Mask.Input(
|
||||||
|
"mask",
|
||||||
|
tooltip="Use the mask to define areas in the video to replace.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.String.Input("prompt_text", multiline=True, optional=True),
|
||||||
|
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||||
|
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True, optional=True),
|
||||||
|
IO.String.Input(
|
||||||
|
"region_to_modify",
|
||||||
|
multiline=True,
|
||||||
|
optional=True,
|
||||||
|
tooltip="Plaintext description of the object / region to modify.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -548,85 +405,65 @@ class PikaSwapsNode(comfy_io.ComfyNode):
|
|||||||
async def execute(
|
async def execute(
|
||||||
cls,
|
cls,
|
||||||
video: VideoInput,
|
video: VideoInput,
|
||||||
image: torch.Tensor,
|
image: Optional[torch.Tensor] = None,
|
||||||
mask: torch.Tensor,
|
mask: Optional[torch.Tensor] = None,
|
||||||
prompt_text: str,
|
prompt_text: str = "",
|
||||||
negative_prompt: str,
|
negative_prompt: str = "",
|
||||||
seed: int,
|
seed: int = 0,
|
||||||
) -> comfy_io.NodeOutput:
|
region_to_modify: str = "",
|
||||||
# Convert video to BytesIO
|
) -> IO.NodeOutput:
|
||||||
video_bytes_io = BytesIO()
|
video_bytes_io = BytesIO()
|
||||||
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||||
video_bytes_io.seek(0)
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
# Convert mask to binary mask with three channels
|
|
||||||
mask = torch.round(mask)
|
|
||||||
mask = mask.repeat(1, 3, 1, 1)
|
|
||||||
|
|
||||||
# Convert 3-channel binary mask to BytesIO
|
|
||||||
mask_bytes_io = BytesIO()
|
|
||||||
mask_bytes_io.write(mask.numpy().astype(np.uint8))
|
|
||||||
mask_bytes_io.seek(0)
|
|
||||||
|
|
||||||
# Convert image to BytesIO
|
|
||||||
image_bytes_io = tensor_to_bytesio(image)
|
|
||||||
image_bytes_io.seek(0)
|
|
||||||
|
|
||||||
pika_files = {
|
pika_files = {
|
||||||
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
"video": ("video.mp4", video_bytes_io, "video/mp4"),
|
||||||
"image": ("image.png", image_bytes_io, "image/png"),
|
|
||||||
"modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"),
|
|
||||||
}
|
}
|
||||||
|
if mask is not None:
|
||||||
|
pika_files["modifyRegionMask"] = ("mask.png", tensor_to_bytesio(mask), "image/png")
|
||||||
|
if image is not None:
|
||||||
|
pika_files["image"] = ("image.png", tensor_to_bytesio(image), "image/png")
|
||||||
|
|
||||||
# Prepare non-file data
|
pika_request_data = pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
||||||
pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
modifyRegionRoi=region_to_modify if region_to_modify else None,
|
||||||
)
|
)
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_request_data,
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKADDITIONS,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=pika_request_data,
|
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaffectsNode(comfy_io.ComfyNode):
|
class PikaffectsNode(IO.ComfyNode):
|
||||||
"""Pika Pikaffects Node."""
|
"""Pika Pikaffects Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Pikaffects",
|
node_id="Pikaffects",
|
||||||
display_name="Pikaffects (Video Effects)",
|
display_name="Pikaffects (Video Effects)",
|
||||||
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
description="Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
IO.Image.Input("image", tooltip="The reference image to apply the Pikaffect to."),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"pikaffect", options=[pikaffect.value for pikaffect in Pikaffect], default="Cake-ify"
|
"pikaffect", options=pika_defs.Pikaffect, default="Cake-ify"
|
||||||
),
|
),
|
||||||
comfy_io.String.Input("prompt_text", multiline=True),
|
IO.String.Input("prompt_text", multiline=True),
|
||||||
comfy_io.String.Input("negative_prompt", multiline=True),
|
IO.String.Input("negative_prompt", multiline=True),
|
||||||
comfy_io.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
IO.Int.Input("seed", min=0, max=0xFFFFFFFF, control_after_generate=True),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -639,19 +476,12 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
prompt_text: str,
|
prompt_text: str,
|
||||||
negative_prompt: str,
|
negative_prompt: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKAFFECTS,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
|
|
||||||
pikaffect=pikaffect,
|
pikaffect=pikaffect,
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
@ -659,31 +489,30 @@ class PikaffectsNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
class PikaStartEndFrameNode(IO.ComfyNode):
|
||||||
"""PikaFrames v2.2 Node."""
|
"""PikaFrames v2.2 Node."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PikaStartEndFrameNode2_2",
|
node_id="PikaStartEndFrameNode2_2",
|
||||||
display_name="Pika Start and End Frame to Video",
|
display_name="Pika Start and End Frame to Video",
|
||||||
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
description="Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them.",
|
||||||
category="api node/video/Pika",
|
category="api node/video/Pika",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image_start", tooltip="The first image to combine."),
|
IO.Image.Input("image_start", tooltip="The first image to combine."),
|
||||||
comfy_io.Image.Input("image_end", tooltip="The last image to combine."),
|
IO.Image.Input("image_end", tooltip="The last image to combine."),
|
||||||
*get_base_inputs_types(),
|
*get_base_inputs_types(),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -698,23 +527,17 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
duration: int,
|
duration: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt_text, field_name="prompt_text", min_length=1)
|
||||||
pika_files = [
|
pika_files = [
|
||||||
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
|
||||||
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
|
||||||
]
|
]
|
||||||
auth = {
|
initial_operation = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
|
||||||
}
|
response_model=pika_defs.PikaGenerateResponse,
|
||||||
initial_operation = SynchronousOperation(
|
data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=PATH_PIKAFRAMES,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
|
|
||||||
response_model=PikaGenerateResponse,
|
|
||||||
),
|
|
||||||
request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
|
|
||||||
promptText=prompt_text,
|
promptText=prompt_text,
|
||||||
negativePrompt=negative_prompt,
|
negativePrompt=negative_prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -723,22 +546,21 @@ class PikaStartEndFrameNode2_2(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=pika_files,
|
files=pika_files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id)
|
return await execute_task(initial_operation.video_id, cls)
|
||||||
|
|
||||||
|
|
||||||
class PikaApiNodesExtension(ComfyExtension):
|
class PikaApiNodesExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
PikaImageToVideoV2_2,
|
PikaImageToVideo,
|
||||||
PikaTextToVideoNodeV2_2,
|
PikaTextToVideoNode,
|
||||||
PikaScenesV2_2,
|
PikaScenes,
|
||||||
PikAdditionsNode,
|
PikAdditionsNode,
|
||||||
PikaSwapsNode,
|
PikaSwapsNode,
|
||||||
PikaffectsNode,
|
PikaffectsNode,
|
||||||
PikaStartEndFrameNode2_2,
|
PikaStartEndFrameNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from inspect import cleandoc
|
import torch
|
||||||
from typing import Optional
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from io import BytesIO
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.apis.pixverse_api import (
|
from comfy_api_nodes.apis.pixverse_api import (
|
||||||
PixverseTextVideoRequest,
|
PixverseTextVideoRequest,
|
||||||
PixverseImageVideoRequest,
|
PixverseImageVideoRequest,
|
||||||
@ -17,125 +16,91 @@ from comfy_api_nodes.apis.pixverse_api import (
|
|||||||
PixverseIO,
|
PixverseIO,
|
||||||
pixverse_templates,
|
pixverse_templates,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_video_output,
|
||||||
SynchronousOperation,
|
poll_op,
|
||||||
PollingOperation,
|
sync_op,
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
validate_string,
|
validate_string,
|
||||||
)
|
)
|
||||||
from comfy_api.input_impl import VideoFromFile
|
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import aiohttp
|
|
||||||
|
|
||||||
|
|
||||||
AVERAGE_DURATION_T2V = 32
|
AVERAGE_DURATION_T2V = 32
|
||||||
AVERAGE_DURATION_I2V = 30
|
AVERAGE_DURATION_I2V = 30
|
||||||
AVERAGE_DURATION_T2T = 52
|
AVERAGE_DURATION_T2T = 52
|
||||||
|
|
||||||
|
|
||||||
def get_video_url_from_response(
|
async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor):
|
||||||
response: PixverseGenerationStatusResponse,
|
response_upload = await sync_op(
|
||||||
) -> Optional[str]:
|
cls,
|
||||||
if response.Resp is None or response.Resp.url is None:
|
ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"),
|
||||||
return None
|
response_model=PixverseImageUploadResponse,
|
||||||
return str(response.Resp.url)
|
files={"image": tensor_to_bytesio(image)},
|
||||||
|
|
||||||
|
|
||||||
async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
|
|
||||||
# first, upload image to Pixverse and get image id to use in actual generation call
|
|
||||||
files = {"image": tensor_to_bytesio(image)}
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/pixverse/image/upload",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PixverseImageUploadResponse,
|
|
||||||
),
|
|
||||||
request=EmptyRequest(),
|
|
||||||
files=files,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
response_upload: PixverseImageUploadResponse = await operation.execute()
|
|
||||||
|
|
||||||
if response_upload.Resp is None:
|
if response_upload.Resp is None:
|
||||||
raise Exception(
|
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
|
||||||
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
return response_upload.Resp.img_id
|
return response_upload.Resp.img_id
|
||||||
|
|
||||||
|
|
||||||
class PixverseTemplateNode(comfy_io.ComfyNode):
|
class PixverseTemplateNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Select template for PixVerse Video generation.
|
Select template for PixVerse Video generation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PixverseTemplateNode",
|
node_id="PixverseTemplateNode",
|
||||||
display_name="PixVerse Template",
|
display_name="PixVerse Template",
|
||||||
category="api node/video/PixVerse",
|
category="api node/video/PixVerse",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input("template", options=[list(pixverse_templates.keys())]),
|
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
|
outputs=[IO.Custom(PixverseIO.TEMPLATE).Output(display_name="pixverse_template")],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, template: str) -> comfy_io.NodeOutput:
|
def execute(cls, template: str) -> IO.NodeOutput:
|
||||||
template_id = pixverse_templates.get(template, None)
|
template_id = pixverse_templates.get(template, None)
|
||||||
if template_id is None:
|
if template_id is None:
|
||||||
raise Exception(f"Template '{template}' is not recognized.")
|
raise Exception(f"Template '{template}' is not recognized.")
|
||||||
# just return the integer
|
return IO.NodeOutput(template_id)
|
||||||
return comfy_io.NodeOutput(template_id)
|
|
||||||
|
|
||||||
|
|
||||||
class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
class PixverseTextToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos based on prompt and output_size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PixverseTextToVideoNode",
|
node_id="PixverseTextToVideoNode",
|
||||||
display_name="PixVerse Text to Video",
|
display_name="PixVerse Text to Video",
|
||||||
category="api node/video/PixVerse",
|
category="api node/video/PixVerse",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos based on prompt and output_size.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the video generation",
|
tooltip="Prompt for the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[ratio.value for ratio in PixverseAspectRatio],
|
options=PixverseAspectRatio,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"quality",
|
"quality",
|
||||||
options=[resolution.value for resolution in PixverseQuality],
|
options=PixverseQuality,
|
||||||
default=PixverseQuality.res_540p,
|
default=PixverseQuality.res_540p,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration_seconds",
|
"duration_seconds",
|
||||||
options=[dur.value for dur in PixverseDuration],
|
options=PixverseDuration,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"motion_mode",
|
"motion_mode",
|
||||||
options=[mode.value for mode in PixverseMotionMode],
|
options=PixverseMotionMode,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -143,24 +108,24 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation.",
|
tooltip="Seed for video generation.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="An optional text description of undesired elements on an image.",
|
tooltip="An optional text description of undesired elements on an image.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
|
IO.Custom(PixverseIO.TEMPLATE).Input(
|
||||||
"pixverse_template",
|
"pixverse_template",
|
||||||
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
|
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -176,8 +141,8 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
negative_prompt: str = None,
|
negative_prompt: str = None,
|
||||||
pixverse_template: int = None,
|
pixverse_template: int = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
if quality == PixverseQuality.res_1080p:
|
if quality == PixverseQuality.res_1080p:
|
||||||
@ -186,18 +151,11 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
elif duration_seconds != PixverseDuration.dur_5:
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
motion_mode = PixverseMotionMode.normal
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"),
|
||||||
}
|
response_model=PixverseVideoResponse,
|
||||||
operation = SynchronousOperation(
|
data=PixverseTextVideoRequest(
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/pixverse/video/text/generate",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=PixverseTextVideoRequest,
|
|
||||||
response_model=PixverseVideoResponse,
|
|
||||||
),
|
|
||||||
request=PixverseTextVideoRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
@ -207,20 +165,14 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
template_id=pixverse_template,
|
template_id=pixverse_template,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||||
method=HttpMethod.GET,
|
response_model=PixverseGenerationStatusResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PixverseGenerationStatusResponse,
|
|
||||||
),
|
|
||||||
completed_statuses=[PixverseStatus.successful],
|
completed_statuses=[PixverseStatus.successful],
|
||||||
failed_statuses=[
|
failed_statuses=[
|
||||||
PixverseStatus.contents_moderation,
|
PixverseStatus.contents_moderation,
|
||||||
@ -228,52 +180,41 @@ class PixverseTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
PixverseStatus.deleted,
|
PixverseStatus.deleted,
|
||||||
],
|
],
|
||||||
status_extractor=lambda x: x.Resp.status,
|
status_extractor=lambda x: x.Resp.status,
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
estimated_duration=AVERAGE_DURATION_T2V,
|
estimated_duration=AVERAGE_DURATION_T2V,
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.Resp.url) as vid_response:
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
|
|
||||||
class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
class PixverseImageToVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos based on prompt and output_size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PixverseImageToVideoNode",
|
node_id="PixverseImageToVideoNode",
|
||||||
display_name="PixVerse Image to Video",
|
display_name="PixVerse Image to Video",
|
||||||
category="api node/video/PixVerse",
|
category="api node/video/PixVerse",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos based on prompt and output_size.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the video generation",
|
tooltip="Prompt for the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"quality",
|
"quality",
|
||||||
options=[resolution.value for resolution in PixverseQuality],
|
options=PixverseQuality,
|
||||||
default=PixverseQuality.res_540p,
|
default=PixverseQuality.res_540p,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration_seconds",
|
"duration_seconds",
|
||||||
options=[dur.value for dur in PixverseDuration],
|
options=PixverseDuration,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"motion_mode",
|
"motion_mode",
|
||||||
options=[mode.value for mode in PixverseMotionMode],
|
options=PixverseMotionMode,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -281,24 +222,24 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation.",
|
tooltip="Seed for video generation.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="An optional text description of undesired elements on an image.",
|
tooltip="An optional text description of undesired elements on an image.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Custom(PixverseIO.TEMPLATE).Input(
|
IO.Custom(PixverseIO.TEMPLATE).Input(
|
||||||
"pixverse_template",
|
"pixverse_template",
|
||||||
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
|
tooltip="An optional template to influence style of generation, created by the PixVerse Template node.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -314,13 +255,9 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
seed,
|
seed,
|
||||||
negative_prompt: str = None,
|
negative_prompt: str = None,
|
||||||
pixverse_template: int = None,
|
pixverse_template: int = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
auth = {
|
img_id = await upload_image_to_pixverse(cls, image)
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
img_id = await upload_image_to_pixverse(image, auth_kwargs=auth)
|
|
||||||
|
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
@ -330,14 +267,11 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
elif duration_seconds != PixverseDuration.dur_5:
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
motion_mode = PixverseMotionMode.normal
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/pixverse/video/img/generate",
|
ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=PixverseVideoResponse,
|
||||||
request_model=PixverseImageVideoRequest,
|
data=PixverseImageVideoRequest(
|
||||||
response_model=PixverseVideoResponse,
|
|
||||||
),
|
|
||||||
request=PixverseImageVideoRequest(
|
|
||||||
img_id=img_id,
|
img_id=img_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
quality=quality,
|
quality=quality,
|
||||||
@ -347,20 +281,15 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
template_id=pixverse_template,
|
template_id=pixverse_template,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||||
method=HttpMethod.GET,
|
response_model=PixverseGenerationStatusResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PixverseGenerationStatusResponse,
|
|
||||||
),
|
|
||||||
completed_statuses=[PixverseStatus.successful],
|
completed_statuses=[PixverseStatus.successful],
|
||||||
failed_statuses=[
|
failed_statuses=[
|
||||||
PixverseStatus.contents_moderation,
|
PixverseStatus.contents_moderation,
|
||||||
@ -368,53 +297,42 @@ class PixverseImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
PixverseStatus.deleted,
|
PixverseStatus.deleted,
|
||||||
],
|
],
|
||||||
status_extractor=lambda x: x.Resp.status,
|
status_extractor=lambda x: x.Resp.status,
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
estimated_duration=AVERAGE_DURATION_I2V,
|
estimated_duration=AVERAGE_DURATION_I2V,
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.Resp.url) as vid_response:
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
|
|
||||||
class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||||
"""
|
|
||||||
Generates videos based on prompt and output_size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="PixverseTransitionVideoNode",
|
node_id="PixverseTransitionVideoNode",
|
||||||
display_name="PixVerse Transition Video",
|
display_name="PixVerse Transition Video",
|
||||||
category="api node/video/PixVerse",
|
category="api node/video/PixVerse",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description="Generates videos based on prompt and output_size.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("first_frame"),
|
IO.Image.Input("first_frame"),
|
||||||
comfy_io.Image.Input("last_frame"),
|
IO.Image.Input("last_frame"),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt for the video generation",
|
tooltip="Prompt for the video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"quality",
|
"quality",
|
||||||
options=[resolution.value for resolution in PixverseQuality],
|
options=PixverseQuality,
|
||||||
default=PixverseQuality.res_540p,
|
default=PixverseQuality.res_540p,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration_seconds",
|
"duration_seconds",
|
||||||
options=[dur.value for dur in PixverseDuration],
|
options=PixverseDuration,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"motion_mode",
|
"motion_mode",
|
||||||
options=[mode.value for mode in PixverseMotionMode],
|
options=PixverseMotionMode,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
@ -422,7 +340,7 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
|||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation.",
|
tooltip="Seed for video generation.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
@ -430,11 +348,11 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.Video.Output()],
|
outputs=[IO.Video.Output()],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -450,14 +368,10 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
|||||||
motion_mode: str,
|
motion_mode: str,
|
||||||
seed,
|
seed,
|
||||||
negative_prompt: str = None,
|
negative_prompt: str = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
auth = {
|
first_frame_id = await upload_image_to_pixverse(cls, first_frame)
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
last_frame_id = await upload_image_to_pixverse(cls, last_frame)
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth)
|
|
||||||
last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth)
|
|
||||||
|
|
||||||
# 1080p is limited to 5 seconds duration
|
# 1080p is limited to 5 seconds duration
|
||||||
# only normal motion_mode supported for 1080p or for non-5 second duration
|
# only normal motion_mode supported for 1080p or for non-5 second duration
|
||||||
@ -467,14 +381,11 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
|||||||
elif duration_seconds != PixverseDuration.dur_5:
|
elif duration_seconds != PixverseDuration.dur_5:
|
||||||
motion_mode = PixverseMotionMode.normal
|
motion_mode = PixverseMotionMode.normal
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/pixverse/video/transition/generate",
|
ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=PixverseVideoResponse,
|
||||||
request_model=PixverseTransitionVideoRequest,
|
data=PixverseTransitionVideoRequest(
|
||||||
response_model=PixverseVideoResponse,
|
|
||||||
),
|
|
||||||
request=PixverseTransitionVideoRequest(
|
|
||||||
first_frame_img=first_frame_id,
|
first_frame_img=first_frame_id,
|
||||||
last_frame_img=last_frame_id,
|
last_frame_img=last_frame_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -484,20 +395,15 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
|||||||
negative_prompt=negative_prompt if negative_prompt else None,
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.Resp is None:
|
if response_api.Resp is None:
|
||||||
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
|
ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"),
|
||||||
method=HttpMethod.GET,
|
response_model=PixverseGenerationStatusResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=PixverseGenerationStatusResponse,
|
|
||||||
),
|
|
||||||
completed_statuses=[PixverseStatus.successful],
|
completed_statuses=[PixverseStatus.successful],
|
||||||
failed_statuses=[
|
failed_statuses=[
|
||||||
PixverseStatus.contents_moderation,
|
PixverseStatus.contents_moderation,
|
||||||
@ -505,21 +411,14 @@ class PixverseTransitionVideoNode(comfy_io.ComfyNode):
|
|||||||
PixverseStatus.deleted,
|
PixverseStatus.deleted,
|
||||||
],
|
],
|
||||||
status_extractor=lambda x: x.Resp.status,
|
status_extractor=lambda x: x.Resp.status,
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
estimated_duration=AVERAGE_DURATION_T2V,
|
estimated_duration=AVERAGE_DURATION_T2V,
|
||||||
)
|
)
|
||||||
response_poll = await operation.execute()
|
return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url))
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(response_poll.Resp.url) as vid_response:
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read())))
|
|
||||||
|
|
||||||
|
|
||||||
class PixVerseExtension(ComfyExtension):
|
class PixVerseExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
PixverseTextToVideoNode,
|
PixverseTextToVideoNode,
|
||||||
PixverseImageToVideoNode,
|
PixverseImageToVideoNode,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -5,12 +5,9 @@ Rodin API docs: https://developer.hyper3d.ai/
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
from inspect import cleandoc
|
from inspect import cleandoc
|
||||||
import folder_paths as comfy_paths
|
import folder_paths as comfy_paths
|
||||||
import aiohttp
|
|
||||||
import os
|
import os
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@ -26,26 +23,26 @@ from comfy_api_nodes.apis.rodin_api import (
|
|||||||
Rodin3DDownloadResponse,
|
Rodin3DDownloadResponse,
|
||||||
JobStatus,
|
JobStatus,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
|
sync_op,
|
||||||
|
poll_op,
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_bytesio,
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
)
|
)
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
|
||||||
|
|
||||||
COMMON_PARAMETERS = [
|
COMMON_PARAMETERS = [
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"Seed",
|
"Seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=65535,
|
max=65535,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"Polygon_count",
|
"Polygon_count",
|
||||||
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
|
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
|
||||||
default="18K-Quad",
|
default="18K-Quad",
|
||||||
@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048):
|
|||||||
|
|
||||||
|
|
||||||
async def create_generate_task(
|
async def create_generate_task(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
images=None,
|
images=None,
|
||||||
seed=1,
|
seed=1,
|
||||||
material="PBR",
|
material="PBR",
|
||||||
quality_override=18000,
|
quality_override=18000,
|
||||||
tier="Regular",
|
tier="Regular",
|
||||||
mesh_mode="Quad",
|
mesh_mode="Quad",
|
||||||
TAPose = False,
|
ta_pose: bool = False,
|
||||||
auth_kwargs: Optional[dict[str, str]] = None,
|
|
||||||
):
|
):
|
||||||
if images is None:
|
if images is None:
|
||||||
raise Exception("Rodin 3D generate requires at least 1 image.")
|
raise Exception("Rodin 3D generate requires at least 1 image.")
|
||||||
if len(images) > 5:
|
if len(images) > 5:
|
||||||
raise Exception("Rodin 3D generate requires up to 5 image.")
|
raise Exception("Rodin 3D generate requires up to 5 image.")
|
||||||
|
|
||||||
path = "/proxy/rodin/api/v2/rodin"
|
response = await sync_op(
|
||||||
operation = SynchronousOperation(
|
cls,
|
||||||
endpoint=ApiEndpoint(
|
ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"),
|
||||||
path=path,
|
response_model=Rodin3DGenerateResponse,
|
||||||
method=HttpMethod.POST,
|
data=Rodin3DGenerateRequest(
|
||||||
request_model=Rodin3DGenerateRequest,
|
|
||||||
response_model=Rodin3DGenerateResponse,
|
|
||||||
),
|
|
||||||
request=Rodin3DGenerateRequest(
|
|
||||||
seed=seed,
|
seed=seed,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
material=material,
|
material=material,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
TAPose=TAPose,
|
TAPose=ta_pose,
|
||||||
),
|
),
|
||||||
files=[
|
files=[
|
||||||
(
|
(
|
||||||
@ -159,11 +152,8 @@ async def create_generate_task(
|
|||||||
for image in images if image is not None
|
for image in images if image is not None
|
||||||
],
|
],
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
response = await operation.execute()
|
|
||||||
|
|
||||||
if hasattr(response, "error"):
|
if hasattr(response, "error"):
|
||||||
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
|
||||||
logging.error(error_message)
|
logging.error(error_message)
|
||||||
@ -172,111 +162,83 @@ async def create_generate_task(
|
|||||||
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
|
||||||
subscription_key = response.jobs.subscription_key
|
subscription_key = response.jobs.subscription_key
|
||||||
task_uuid = response.uuid
|
task_uuid = response.uuid
|
||||||
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
|
logging.info("[ Rodin3D API - Submit Jobs ] UUID: %s", task_uuid)
|
||||||
return task_uuid, subscription_key
|
return task_uuid, subscription_key
|
||||||
|
|
||||||
|
|
||||||
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str:
|
||||||
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
all_done = all(job.status == JobStatus.Done for job in response.jobs)
|
||||||
status_list = [str(job.status) for job in response.jobs]
|
status_list = [str(job.status) for job in response.jobs]
|
||||||
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
|
logging.info("[ Rodin3D API - CheckStatus ] Generate Status: %s", status_list)
|
||||||
if any(job.status == JobStatus.Failed for job in response.jobs):
|
if any(job.status == JobStatus.Failed for job in response.jobs):
|
||||||
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
|
logging.error("[ Rodin3D API - CheckStatus ] Generate Failed: %s, Please try again.", status_list)
|
||||||
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
|
||||||
if all_done:
|
if all_done:
|
||||||
return "DONE"
|
return "DONE"
|
||||||
return "Generating"
|
return "Generating"
|
||||||
|
|
||||||
|
def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]:
|
||||||
|
if not response.jobs:
|
||||||
|
return None
|
||||||
|
completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done)
|
||||||
|
return int((completed_count / len(response.jobs)) * 100)
|
||||||
|
|
||||||
async def poll_for_task_status(
|
|
||||||
subscription_key, auth_kwargs: Optional[dict[str, str]] = None,
|
async def poll_for_task_status(subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse:
|
||||||
) -> Rodin3DCheckStatusResponse:
|
|
||||||
poll_operation = PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/rodin/api/v2/status",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=Rodin3DCheckStatusRequest,
|
|
||||||
response_model=Rodin3DCheckStatusResponse,
|
|
||||||
),
|
|
||||||
request=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
|
|
||||||
completed_statuses=["DONE"],
|
|
||||||
failed_statuses=["FAILED"],
|
|
||||||
status_extractor=check_rodin_status,
|
|
||||||
poll_interval=3.0,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
|
||||||
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
|
||||||
return await poll_operation.execute()
|
return await poll_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"),
|
||||||
async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse:
|
response_model=Rodin3DCheckStatusResponse,
|
||||||
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
data=Rodin3DCheckStatusRequest(subscription_key=subscription_key),
|
||||||
operation = SynchronousOperation(
|
status_extractor=check_rodin_status,
|
||||||
endpoint=ApiEndpoint(
|
progress_extractor=extract_progress,
|
||||||
path="/proxy/rodin/api/v2/download",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=Rodin3DDownloadRequest,
|
|
||||||
response_model=Rodin3DDownloadResponse,
|
|
||||||
),
|
|
||||||
request=Rodin3DDownloadRequest(task_uuid=uuid),
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
return await operation.execute()
|
|
||||||
|
|
||||||
|
|
||||||
async def download_files(url_list, task_uuid):
|
async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse:
|
||||||
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}")
|
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
|
||||||
|
return await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"),
|
||||||
|
response_model=Rodin3DDownloadResponse,
|
||||||
|
data=Rodin3DDownloadRequest(task_uuid=uuid),
|
||||||
|
monitor_progress=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_files(url_list, task_uuid: str):
|
||||||
|
result_folder_name = f"Rodin3D_{task_uuid}"
|
||||||
|
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
|
||||||
os.makedirs(save_path, exist_ok=True)
|
os.makedirs(save_path, exist_ok=True)
|
||||||
model_file_path = None
|
model_file_path = None
|
||||||
async with aiohttp.ClientSession() as session:
|
for i in url_list.list:
|
||||||
for i in url_list.list:
|
file_path = os.path.join(save_path, i.name)
|
||||||
url = i.url
|
if file_path.endswith(".glb"):
|
||||||
file_name = i.name
|
model_file_path = os.path.join(result_folder_name, i.name)
|
||||||
file_path = os.path.join(save_path, file_name)
|
await download_url_to_bytesio(i.url, file_path)
|
||||||
if file_path.endswith(".glb"):
|
|
||||||
model_file_path = file_path
|
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
|
|
||||||
max_retries = 5
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
async with session.get(url) as resp:
|
|
||||||
resp.raise_for_status()
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
async for chunk in resp.content.iter_chunked(32 * 1024):
|
|
||||||
f.write(chunk)
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logging.info("Retrying...")
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
else:
|
|
||||||
logging.info(
|
|
||||||
"[ Rodin3D API - download_files ] Failed to download %s after %s attempts.",
|
|
||||||
file_path,
|
|
||||||
max_retries,
|
|
||||||
)
|
|
||||||
return model_file_path
|
return model_file_path
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Regular(comfy_io.ComfyNode):
|
class Rodin3D_Regular(IO.ComfyNode):
|
||||||
"""Generate 3D Assets using Rodin API"""
|
"""Generate 3D Assets using Rodin API"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Rodin3D_Regular",
|
node_id="Rodin3D_Regular",
|
||||||
display_name="Rodin 3D Generate - Regular Generate",
|
display_name="Rodin 3D Generate - Regular Generate",
|
||||||
category="api node/3d/Rodin",
|
category="api node/3d/Rodin",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("Images"),
|
IO.Image.Input("Images"),
|
||||||
*COMMON_PARAMETERS,
|
*COMMON_PARAMETERS,
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -288,51 +250,48 @@ class Rodin3D_Regular(comfy_io.ComfyNode):
|
|||||||
Seed,
|
Seed,
|
||||||
Material_Type,
|
Material_Type,
|
||||||
Polygon_count,
|
Polygon_count,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Regular"
|
tier = "Regular"
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Detail(comfy_io.ComfyNode):
|
class Rodin3D_Detail(IO.ComfyNode):
|
||||||
"""Generate 3D Assets using Rodin API"""
|
"""Generate 3D Assets using Rodin API"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Rodin3D_Detail",
|
node_id="Rodin3D_Detail",
|
||||||
display_name="Rodin 3D Generate - Detail Generate",
|
display_name="Rodin 3D Generate - Detail Generate",
|
||||||
category="api node/3d/Rodin",
|
category="api node/3d/Rodin",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("Images"),
|
IO.Image.Input("Images"),
|
||||||
*COMMON_PARAMETERS,
|
*COMMON_PARAMETERS,
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -344,51 +303,48 @@ class Rodin3D_Detail(comfy_io.ComfyNode):
|
|||||||
Seed,
|
Seed,
|
||||||
Material_Type,
|
Material_Type,
|
||||||
Polygon_count,
|
Polygon_count,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Detail"
|
tier = "Detail"
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Smooth(comfy_io.ComfyNode):
|
class Rodin3D_Smooth(IO.ComfyNode):
|
||||||
"""Generate 3D Assets using Rodin API"""
|
"""Generate 3D Assets using Rodin API"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Rodin3D_Smooth",
|
node_id="Rodin3D_Smooth",
|
||||||
display_name="Rodin 3D Generate - Smooth Generate",
|
display_name="Rodin 3D Generate - Smooth Generate",
|
||||||
category="api node/3d/Rodin",
|
category="api node/3d/Rodin",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("Images"),
|
IO.Image.Input("Images"),
|
||||||
*COMMON_PARAMETERS,
|
*COMMON_PARAMETERS,
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -400,58 +356,54 @@ class Rodin3D_Smooth(comfy_io.ComfyNode):
|
|||||||
Seed,
|
Seed,
|
||||||
Material_Type,
|
Material_Type,
|
||||||
Polygon_count,
|
Polygon_count,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Smooth"
|
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier="Smooth",
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Sketch(comfy_io.ComfyNode):
|
class Rodin3D_Sketch(IO.ComfyNode):
|
||||||
"""Generate 3D Assets using Rodin API"""
|
"""Generate 3D Assets using Rodin API"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Rodin3D_Sketch",
|
node_id="Rodin3D_Sketch",
|
||||||
display_name="Rodin 3D Generate - Sketch Generate",
|
display_name="Rodin 3D Generate - Sketch Generate",
|
||||||
category="api node/3d/Rodin",
|
category="api node/3d/Rodin",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("Images"),
|
IO.Image.Input("Images"),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"Seed",
|
"Seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=65535,
|
max=65535,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -461,68 +413,61 @@ class Rodin3D_Sketch(comfy_io.ComfyNode):
|
|||||||
cls,
|
cls,
|
||||||
Images,
|
Images,
|
||||||
Seed,
|
Seed,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Sketch"
|
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
material_type = "PBR"
|
|
||||||
quality_override = 18000
|
|
||||||
mesh_mode = "Quad"
|
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=material_type,
|
material="PBR",
|
||||||
quality_override=quality_override,
|
quality_override=18000,
|
||||||
tier=tier,
|
tier="Sketch",
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode="Quad",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3D_Gen2(comfy_io.ComfyNode):
|
class Rodin3D_Gen2(IO.ComfyNode):
|
||||||
"""Generate 3D Assets using Rodin API"""
|
"""Generate 3D Assets using Rodin API"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> comfy_io.Schema:
|
def define_schema(cls) -> IO.Schema:
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Rodin3D_Gen2",
|
node_id="Rodin3D_Gen2",
|
||||||
display_name="Rodin 3D Generate - Gen-2 Generate",
|
display_name="Rodin 3D Generate - Gen-2 Generate",
|
||||||
category="api node/3d/Rodin",
|
category="api node/3d/Rodin",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("Images"),
|
IO.Image.Input("Images"),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"Seed",
|
"Seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=65535,
|
max=65535,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"Polygon_count",
|
"Polygon_count",
|
||||||
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
|
||||||
default="500K-Triangle",
|
default="500K-Triangle",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input("TAPose", default=False),
|
IO.Boolean.Input("TAPose", default=False),
|
||||||
],
|
],
|
||||||
outputs=[comfy_io.String.Output(display_name="3D Model Path")],
|
outputs=[IO.String.Output(display_name="3D Model Path")],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -535,37 +480,33 @@ class Rodin3D_Gen2(comfy_io.ComfyNode):
|
|||||||
Material_Type,
|
Material_Type,
|
||||||
Polygon_count,
|
Polygon_count,
|
||||||
TAPose,
|
TAPose,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
tier = "Gen-2"
|
tier = "Gen-2"
|
||||||
num_images = Images.shape[0]
|
num_images = Images.shape[0]
|
||||||
m_images = []
|
m_images = []
|
||||||
for i in range(num_images):
|
for i in range(num_images):
|
||||||
m_images.append(Images[i])
|
m_images.append(Images[i])
|
||||||
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
mesh_mode, quality_override = get_quality_mode(Polygon_count)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
task_uuid, subscription_key = await create_generate_task(
|
task_uuid, subscription_key = await create_generate_task(
|
||||||
|
cls,
|
||||||
images=m_images,
|
images=m_images,
|
||||||
seed=Seed,
|
seed=Seed,
|
||||||
material=Material_Type,
|
material=Material_Type,
|
||||||
quality_override=quality_override,
|
quality_override=quality_override,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
mesh_mode=mesh_mode,
|
mesh_mode=mesh_mode,
|
||||||
TAPose=TAPose,
|
ta_pose=TAPose,
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
await poll_for_task_status(subscription_key, auth_kwargs=auth)
|
await poll_for_task_status(subscription_key, cls)
|
||||||
download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth)
|
download_list = await get_rodin_download_list(task_uuid, cls)
|
||||||
model = await download_files(download_list, task_uuid)
|
model = await download_files(download_list, task_uuid)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(model)
|
return IO.NodeOutput(model)
|
||||||
|
|
||||||
|
|
||||||
class Rodin3DExtension(ComfyExtension):
|
class Rodin3DExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
Rodin3D_Regular,
|
Rodin3D_Regular,
|
||||||
Rodin3D_Detail,
|
Rodin3D_Detail,
|
||||||
|
|||||||
@ -11,7 +11,7 @@ User Guides:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Union, Optional, Any
|
from typing import Union, Optional
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
@ -21,7 +21,6 @@ from comfy_api_nodes.apis import (
|
|||||||
RunwayImageToVideoRequest,
|
RunwayImageToVideoRequest,
|
||||||
RunwayImageToVideoResponse,
|
RunwayImageToVideoResponse,
|
||||||
RunwayTaskStatusResponse as TaskStatusResponse,
|
RunwayTaskStatusResponse as TaskStatusResponse,
|
||||||
RunwayTaskStatusEnum as TaskStatus,
|
|
||||||
RunwayModelEnum as Model,
|
RunwayModelEnum as Model,
|
||||||
RunwayDurationEnum as Duration,
|
RunwayDurationEnum as Duration,
|
||||||
RunwayAspectRatioEnum as AspectRatio,
|
RunwayAspectRatioEnum as AspectRatio,
|
||||||
@ -33,23 +32,20 @@ from comfy_api_nodes.apis import (
|
|||||||
ReferenceImage,
|
ReferenceImage,
|
||||||
RunwayTextToImageAspectRatioEnum,
|
RunwayTextToImageAspectRatioEnum,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
upload_images_to_comfyapi,
|
|
||||||
download_url_to_video_output,
|
|
||||||
image_tensor_pair_to_batch,
|
image_tensor_pair_to_batch,
|
||||||
validate_string,
|
validate_string,
|
||||||
|
validate_image_dimensions,
|
||||||
|
validate_image_aspect_ratio,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
download_url_to_video_output,
|
||||||
download_url_to_image_tensor,
|
download_url_to_image_tensor,
|
||||||
|
ApiEndpoint,
|
||||||
|
sync_op,
|
||||||
|
poll_op,
|
||||||
)
|
)
|
||||||
from comfy_api.input_impl import VideoFromFile
|
from comfy_api.input_impl import VideoFromFile
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio
|
|
||||||
|
|
||||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||||
@ -91,31 +87,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def poll_until_finished(
|
|
||||||
auth_kwargs: dict[str, str],
|
|
||||||
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
|
|
||||||
estimated_duration: Optional[int] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> TaskStatusResponse:
|
|
||||||
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
|
|
||||||
return await PollingOperation(
|
|
||||||
poll_endpoint=api_endpoint,
|
|
||||||
completed_statuses=[
|
|
||||||
TaskStatus.SUCCEEDED.value,
|
|
||||||
],
|
|
||||||
failed_statuses=[
|
|
||||||
TaskStatus.FAILED.value,
|
|
||||||
TaskStatus.CANCELLED.value,
|
|
||||||
],
|
|
||||||
status_extractor=lambda response: response.status.value,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
result_url_extractor=get_video_url_from_task_status,
|
|
||||||
estimated_duration=estimated_duration,
|
|
||||||
node_id=node_id,
|
|
||||||
progress_extractor=extract_progress_from_task_status,
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
|
|
||||||
def extract_progress_from_task_status(
|
def extract_progress_from_task_status(
|
||||||
response: TaskStatusResponse,
|
response: TaskStatusResponse,
|
||||||
) -> Union[float, None]:
|
) -> Union[float, None]:
|
||||||
@ -132,42 +103,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
|||||||
|
|
||||||
|
|
||||||
async def get_response(
|
async def get_response(
|
||||||
task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None
|
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||||
) -> TaskStatusResponse:
|
) -> TaskStatusResponse:
|
||||||
"""Poll the task status until it is finished then get the response."""
|
"""Poll the task status until it is finished then get the response."""
|
||||||
return await poll_until_finished(
|
return await poll_op(
|
||||||
auth_kwargs,
|
cls,
|
||||||
ApiEndpoint(
|
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||||
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
|
response_model=TaskStatusResponse,
|
||||||
method=HttpMethod.GET,
|
status_extractor=lambda r: r.status.value,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=TaskStatusResponse,
|
|
||||||
),
|
|
||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
node_id=node_id,
|
progress_extractor=extract_progress_from_task_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def generate_video(
|
async def generate_video(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
request: RunwayImageToVideoRequest,
|
request: RunwayImageToVideoRequest,
|
||||||
auth_kwargs: dict[str, str],
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
estimated_duration: Optional[int] = None,
|
estimated_duration: Optional[int] = None,
|
||||||
) -> VideoFromFile:
|
) -> VideoFromFile:
|
||||||
initial_operation = SynchronousOperation(
|
initial_response = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path=PATH_IMAGE_TO_VIDEO,
|
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=RunwayImageToVideoResponse,
|
||||||
request_model=RunwayImageToVideoRequest,
|
data=request,
|
||||||
response_model=RunwayImageToVideoResponse,
|
|
||||||
),
|
|
||||||
request=request,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_response = await initial_operation.execute()
|
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||||
|
|
||||||
final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration)
|
|
||||||
if not final_response.output:
|
if not final_response.output:
|
||||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||||
|
|
||||||
@ -175,55 +136,55 @@ async def generate_video(
|
|||||||
return await download_url_to_video_output(video_url)
|
return await download_url_to_video_output(video_url)
|
||||||
|
|
||||||
|
|
||||||
class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="RunwayImageToVideoNodeGen3a",
|
node_id="RunwayImageToVideoNodeGen3a",
|
||||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||||
category="api node/video/Runway",
|
category="api node/video/Runway",
|
||||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||||
"Before diving in, review these best practices to ensure that "
|
"Before diving in, review these best practices to ensure that "
|
||||||
"your input selections will set your generation up for success: "
|
"your input selections will set your generation up for success: "
|
||||||
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
"https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt for the generation",
|
tooltip="Text prompt for the generation",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"start_frame",
|
"start_frame",
|
||||||
tooltip="Start frame to be used for the video",
|
tooltip="Start frame to be used for the video",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[model.value for model in Duration],
|
options=Duration,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ratio",
|
"ratio",
|
||||||
options=[model.value for model in RunwayGen3aAspectRatio],
|
options=RunwayGen3aAspectRatio,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Random seed for generation",
|
tooltip="Random seed for generation",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -236,25 +197,21 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
|||||||
duration: str,
|
duration: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||||
|
|
||||||
auth_kwargs = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
start_frame,
|
start_frame,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await generate_video(
|
await generate_video(
|
||||||
|
cls,
|
||||||
RunwayImageToVideoRequest(
|
RunwayImageToVideoRequest(
|
||||||
promptText=prompt,
|
promptText=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -262,68 +219,62 @@ class RunwayImageToVideoNodeGen3a(comfy_io.ComfyNode):
|
|||||||
duration=Duration(duration),
|
duration=Duration(duration),
|
||||||
ratio=AspectRatio(ratio),
|
ratio=AspectRatio(ratio),
|
||||||
promptImage=RunwayPromptImageObject(
|
promptImage=RunwayPromptImageObject(
|
||||||
root=[
|
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||||
RunwayPromptImageDetailedObject(
|
|
||||||
uri=str(download_urls[0]), position="first"
|
|
||||||
)
|
|
||||||
]
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="RunwayImageToVideoNodeGen4",
|
node_id="RunwayImageToVideoNodeGen4",
|
||||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||||
category="api node/video/Runway",
|
category="api node/video/Runway",
|
||||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||||
"Before diving in, review these best practices to ensure that "
|
"Before diving in, review these best practices to ensure that "
|
||||||
"your input selections will set your generation up for success: "
|
"your input selections will set your generation up for success: "
|
||||||
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
"https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt for the generation",
|
tooltip="Text prompt for the generation",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"start_frame",
|
"start_frame",
|
||||||
tooltip="Start frame to be used for the video",
|
tooltip="Start frame to be used for the video",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[model.value for model in Duration],
|
options=Duration,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ratio",
|
"ratio",
|
||||||
options=[model.value for model in RunwayGen4TurboAspectRatio],
|
options=RunwayGen4TurboAspectRatio,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Random seed for generation",
|
tooltip="Random seed for generation",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -336,25 +287,21 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
|||||||
duration: str,
|
duration: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||||
|
|
||||||
auth_kwargs = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
start_frame,
|
start_frame,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return comfy_io.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await generate_video(
|
await generate_video(
|
||||||
|
cls,
|
||||||
RunwayImageToVideoRequest(
|
RunwayImageToVideoRequest(
|
||||||
promptText=prompt,
|
promptText=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -362,76 +309,70 @@ class RunwayImageToVideoNodeGen4(comfy_io.ComfyNode):
|
|||||||
duration=Duration(duration),
|
duration=Duration(duration),
|
||||||
ratio=AspectRatio(ratio),
|
ratio=AspectRatio(ratio),
|
||||||
promptImage=RunwayPromptImageObject(
|
promptImage=RunwayPromptImageObject(
|
||||||
root=[
|
root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")]
|
||||||
RunwayPromptImageDetailedObject(
|
|
||||||
uri=str(download_urls[0]), position="first"
|
|
||||||
)
|
|
||||||
]
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="RunwayFirstLastFrameNode",
|
node_id="RunwayFirstLastFrameNode",
|
||||||
display_name="Runway First-Last-Frame to Video",
|
display_name="Runway First-Last-Frame to Video",
|
||||||
category="api node/video/Runway",
|
category="api node/video/Runway",
|
||||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||||
"More complex transitions, such as cases where the Last frame is completely different "
|
"More complex transitions, such as cases where the Last frame is completely different "
|
||||||
"from the First frame, may benefit from the longer 10s duration. "
|
"from the First frame, may benefit from the longer 10s duration. "
|
||||||
"This would give the generation more time to smoothly transition between the two inputs. "
|
"This would give the generation more time to smoothly transition between the two inputs. "
|
||||||
"Before diving in, review these best practices to ensure that your input selections "
|
"Before diving in, review these best practices to ensure that your input selections "
|
||||||
"will set your generation up for success: "
|
"will set your generation up for success: "
|
||||||
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
"https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt for the generation",
|
tooltip="Text prompt for the generation",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"start_frame",
|
"start_frame",
|
||||||
tooltip="Start frame to be used for the video",
|
tooltip="Start frame to be used for the video",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"end_frame",
|
"end_frame",
|
||||||
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
|
tooltip="End frame to be used for the video. Supported for gen3a_turbo only.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[model.value for model in Duration],
|
options=Duration,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ratio",
|
"ratio",
|
||||||
options=[model.value for model in RunwayGen3aAspectRatio],
|
options=RunwayGen3aAspectRatio,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967295,
|
max=4294967295,
|
||||||
step=1,
|
step=1,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Random seed for generation",
|
tooltip="Random seed for generation",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -445,30 +386,26 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
|||||||
duration: str,
|
duration: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(start_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
validate_image_dimensions(end_frame, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(start_frame, (1, 2), (2, 1))
|
||||||
validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(end_frame, (1, 2), (2, 1))
|
||||||
|
|
||||||
auth_kwargs = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
stacked_input_images,
|
stacked_input_images,
|
||||||
max_images=2,
|
max_images=2,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
if len(download_urls) != 2:
|
if len(download_urls) != 2:
|
||||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||||
|
|
||||||
return comfy_io.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await generate_video(
|
await generate_video(
|
||||||
|
cls,
|
||||||
RunwayImageToVideoRequest(
|
RunwayImageToVideoRequest(
|
||||||
promptText=prompt,
|
promptText=prompt,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
@ -477,56 +414,50 @@ class RunwayFirstLastFrameNode(comfy_io.ComfyNode):
|
|||||||
ratio=AspectRatio(ratio),
|
ratio=AspectRatio(ratio),
|
||||||
promptImage=RunwayPromptImageObject(
|
promptImage=RunwayPromptImageObject(
|
||||||
root=[
|
root=[
|
||||||
RunwayPromptImageDetailedObject(
|
RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"),
|
||||||
uri=str(download_urls[0]), position="first"
|
RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"),
|
||||||
),
|
|
||||||
RunwayPromptImageDetailedObject(
|
|
||||||
uri=str(download_urls[1]), position="last"
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunwayTextToImageNode(comfy_io.ComfyNode):
|
class RunwayTextToImageNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="RunwayTextToImageNode",
|
node_id="RunwayTextToImageNode",
|
||||||
display_name="Runway Text to Image",
|
display_name="Runway Text to Image",
|
||||||
category="api node/image/Runway",
|
category="api node/image/Runway",
|
||||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||||
"You can also include reference image to guide the generation.",
|
"You can also include reference image to guide the generation.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text prompt for the generation",
|
tooltip="Text prompt for the generation",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ratio",
|
"ratio",
|
||||||
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
|
options=[model.value for model in RunwayTextToImageAspectRatioEnum],
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"reference_image",
|
"reference_image",
|
||||||
tooltip="Optional reference image to guide the generation",
|
tooltip="Optional reference image to guide the generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -537,63 +468,48 @@ class RunwayTextToImageNode(comfy_io.ComfyNode):
|
|||||||
prompt: str,
|
prompt: str,
|
||||||
ratio: str,
|
ratio: str,
|
||||||
reference_image: Optional[torch.Tensor] = None,
|
reference_image: Optional[torch.Tensor] = None,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, min_length=1)
|
validate_string(prompt, min_length=1)
|
||||||
|
|
||||||
auth_kwargs = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Prepare reference images if provided
|
# Prepare reference images if provided
|
||||||
reference_images = None
|
reference_images = None
|
||||||
if reference_image is not None:
|
if reference_image is not None:
|
||||||
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
validate_image_dimensions(reference_image, max_width=7999, max_height=7999)
|
||||||
validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)
|
validate_image_aspect_ratio(reference_image, (1, 2), (2, 1))
|
||||||
download_urls = await upload_images_to_comfyapi(
|
download_urls = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
reference_image,
|
reference_image,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
|
||||||
|
|
||||||
request = RunwayTextToImageRequest(
|
initial_response = await sync_op(
|
||||||
promptText=prompt,
|
cls,
|
||||||
model=Model4.gen4_image,
|
endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"),
|
||||||
ratio=ratio,
|
response_model=RunwayTextToImageResponse,
|
||||||
referenceImages=reference_images,
|
data=RunwayTextToImageRequest(
|
||||||
)
|
promptText=prompt,
|
||||||
|
model=Model4.gen4_image,
|
||||||
initial_operation = SynchronousOperation(
|
ratio=ratio,
|
||||||
endpoint=ApiEndpoint(
|
referenceImages=reference_images,
|
||||||
path=PATH_TEXT_TO_IMAGE,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=RunwayTextToImageRequest,
|
|
||||||
response_model=RunwayTextToImageResponse,
|
|
||||||
),
|
),
|
||||||
request=request,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_response = await initial_operation.execute()
|
|
||||||
|
|
||||||
# Poll for completion
|
|
||||||
final_response = await get_response(
|
final_response = await get_response(
|
||||||
|
cls,
|
||||||
initial_response.id,
|
initial_response.id,
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||||
)
|
)
|
||||||
if not final_response.output:
|
if not final_response.output:
|
||||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||||
|
|
||||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||||
|
|
||||||
|
|
||||||
class RunwayExtension(ComfyExtension):
|
class RunwayExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
RunwayFirstLastFrameNode,
|
RunwayFirstLastFrameNode,
|
||||||
RunwayImageToVideoNodeGen3a,
|
RunwayImageToVideoNodeGen3a,
|
||||||
@ -601,5 +517,6 @@ class RunwayExtension(ComfyExtension):
|
|||||||
RunwayTextToImageNode,
|
RunwayTextToImageNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> RunwayExtension:
|
async def comfy_entrypoint() -> RunwayExtension:
|
||||||
return RunwayExtension()
|
return RunwayExtension()
|
||||||
|
|||||||
@ -1,23 +1,20 @@
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from typing_extensions import override
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.util.validation_utils import get_number_of_images
|
|
||||||
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
ApiEndpoint,
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
|
get_number_of_images,
|
||||||
|
poll_op,
|
||||||
|
sync_op,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Sora2GenerationRequest(BaseModel):
|
class Sora2GenerationRequest(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
@ -31,27 +28,27 @@ class Sora2GenerationResponse(BaseModel):
|
|||||||
status: Optional[str] = Field(None)
|
status: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIVideoSora2(comfy_io.ComfyNode):
|
class OpenAIVideoSora2(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="OpenAIVideoSora2",
|
node_id="OpenAIVideoSora2",
|
||||||
display_name="OpenAI Sora - Video",
|
display_name="OpenAI Sora - Video",
|
||||||
category="api node/video/Sora",
|
category="api node/video/Sora",
|
||||||
description="OpenAI video and audio generation.",
|
description="OpenAI video and audio generation.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["sora-2", "sora-2-pro"],
|
options=["sora-2", "sora-2-pro"],
|
||||||
default="sora-2",
|
default="sora-2",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Guiding text; may be empty if an input image is present.",
|
tooltip="Guiding text; may be empty if an input image is present.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"size",
|
"size",
|
||||||
options=[
|
options=[
|
||||||
"720x1280",
|
"720x1280",
|
||||||
@ -61,35 +58,35 @@ class OpenAIVideoSora2(comfy_io.ComfyNode):
|
|||||||
],
|
],
|
||||||
default="1280x720",
|
default="1280x720",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"duration",
|
"duration",
|
||||||
options=[4, 8, 12],
|
options=[4, 8, 12],
|
||||||
default=8,
|
default=8,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Seed to determine if node should re-run; "
|
tooltip="Seed to determine if node should re-run; "
|
||||||
"actual results are nondeterministic regardless of seed.",
|
"actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -111,61 +108,40 @@ class OpenAIVideoSora2(comfy_io.ComfyNode):
|
|||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
raise ValueError("Currently only one input image is supported.")
|
raise ValueError("Currently only one input image is supported.")
|
||||||
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
|
files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")}
|
||||||
auth = {
|
initial_response = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"),
|
||||||
}
|
data=Sora2GenerationRequest(
|
||||||
payload = Sora2GenerationRequest(
|
model=model,
|
||||||
model=model,
|
prompt=prompt,
|
||||||
prompt=prompt,
|
seconds=str(duration),
|
||||||
seconds=str(duration),
|
size=size,
|
||||||
size=size,
|
|
||||||
)
|
|
||||||
initial_operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/openai/v1/videos",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=Sora2GenerationRequest,
|
|
||||||
response_model=Sora2GenerationResponse
|
|
||||||
),
|
),
|
||||||
request=payload,
|
|
||||||
files=files_input,
|
files=files_input,
|
||||||
auth_kwargs=auth,
|
response_model=Sora2GenerationResponse,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
)
|
)
|
||||||
initial_response = await initial_operation.execute()
|
|
||||||
if initial_response.error:
|
if initial_response.error:
|
||||||
raise Exception(initial_response.error.message)
|
raise Exception(initial_response.error["message"])
|
||||||
|
|
||||||
model_time_multiplier = 1 if model == "sora-2" else 2
|
model_time_multiplier = 1 if model == "sora-2" else 2
|
||||||
poll_operation = PollingOperation(
|
await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/openai/v1/videos/{initial_response.id}",
|
poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"),
|
||||||
method=HttpMethod.GET,
|
response_model=Sora2GenerationResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=Sora2GenerationResponse
|
|
||||||
),
|
|
||||||
completed_statuses=["completed"],
|
|
||||||
failed_statuses=["failed"],
|
|
||||||
status_extractor=lambda x: x.status,
|
status_extractor=lambda x: x.status,
|
||||||
auth_kwargs=auth,
|
|
||||||
poll_interval=8.0,
|
poll_interval=8.0,
|
||||||
max_poll_attempts=160,
|
max_poll_attempts=160,
|
||||||
node_id=cls.hidden.unique_id,
|
estimated_duration=int(45 * (duration / 4) * model_time_multiplier),
|
||||||
estimated_duration=45 * (duration / 4) * model_time_multiplier,
|
|
||||||
)
|
)
|
||||||
await poll_operation.execute()
|
return IO.NodeOutput(
|
||||||
return comfy_io.NodeOutput(
|
await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls),
|
||||||
await download_url_to_video_output(
|
|
||||||
f"/proxy/openai/v1/videos/{initial_response.id}/content",
|
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OpenAISoraExtension(ComfyExtension):
|
class OpenAISoraExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
OpenAIVideoSora2,
|
OpenAIVideoSora2,
|
||||||
]
|
]
|
||||||
|
|||||||
@ -2,7 +2,7 @@ from inspect import cleandoc
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
from comfy_api.latest import ComfyExtension, Input, IO
|
||||||
from comfy_api_nodes.apis.stability_api import (
|
from comfy_api_nodes.apis.stability_api import (
|
||||||
StabilityUpscaleConservativeRequest,
|
StabilityUpscaleConservativeRequest,
|
||||||
StabilityUpscaleCreativeRequest,
|
StabilityUpscaleCreativeRequest,
|
||||||
@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import (
|
|||||||
StabilityAudioInpaintRequest,
|
StabilityAudioInpaintRequest,
|
||||||
StabilityAudioResponse,
|
StabilityAudioResponse,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
validate_audio_duration,
|
||||||
HttpMethod,
|
validate_string,
|
||||||
SynchronousOperation,
|
audio_input_to_mp3,
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
bytesio_to_image_tensor,
|
bytesio_to_image_tensor,
|
||||||
tensor_to_bytesio,
|
tensor_to_bytesio,
|
||||||
validate_string,
|
|
||||||
audio_bytes_to_audio_input,
|
audio_bytes_to_audio_input,
|
||||||
audio_input_to_mp3,
|
sync_op,
|
||||||
|
poll_op,
|
||||||
|
ApiEndpoint,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.util.validation_utils import validate_audio_duration
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import base64
|
import base64
|
||||||
@ -56,20 +52,20 @@ def get_async_dummy_status(x: StabilityResultsGetResponse):
|
|||||||
return StabilityPollStatus.in_progress
|
return StabilityPollStatus.in_progress
|
||||||
|
|
||||||
|
|
||||||
class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Generates images synchronously based on prompt and resolution.
|
Generates images synchronously based on prompt and resolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityStableImageUltraNode",
|
node_id="StabilityStableImageUltraNode",
|
||||||
display_name="Stability AI Stable Image Ultra",
|
display_name="Stability AI Stable Image Ultra",
|
||||||
category="api node/image/Stability AI",
|
category="api node/image/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
@ -80,39 +76,39 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
|||||||
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
"is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
|
||||||
"would convey a sky that was blue and green, but more green than blue.",
|
"would convey a sky that was blue and green, but more green than blue.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[x.value for x in StabilityAspectRatio],
|
options=StabilityAspectRatio,
|
||||||
default=StabilityAspectRatio.ratio_1_1.value,
|
default=StabilityAspectRatio.ratio_1_1,
|
||||||
tooltip="Aspect ratio of generated image.",
|
tooltip="Aspect ratio of generated image.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"style_preset",
|
"style_preset",
|
||||||
options=get_stability_style_presets(),
|
options=get_stability_style_presets(),
|
||||||
tooltip="Optional desired style of generated image.",
|
tooltip="Optional desired style of generated image.",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for creating the noise.",
|
tooltip="The random seed used for creating the noise.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
tooltip="A blurb of text describing what you do not wish to see in the output image. This is an advanced feature.",
|
||||||
force_input=True,
|
force_input=True,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"image_denoise",
|
"image_denoise",
|
||||||
default=0.5,
|
default=0.5,
|
||||||
min=0.0,
|
min=0.0,
|
||||||
@ -123,12 +119,12 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -143,7 +139,7 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
|||||||
image: Optional[torch.Tensor] = None,
|
image: Optional[torch.Tensor] = None,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
image_denoise: Optional[float] = 0.5,
|
image_denoise: Optional[float] = 0.5,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
# prepare image binary if image present
|
# prepare image binary if image present
|
||||||
image_binary = None
|
image_binary = None
|
||||||
@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
data=StabilityStableUltraRequest(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/generate/ultra",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityStableUltraRequest,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=StabilityStableUltraRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
|
||||||
@ -193,44 +179,44 @@ class StabilityStableImageUltraNode(comfy_io.ComfyNode):
|
|||||||
image_data = base64.b64decode(response_api.image)
|
image_data = base64.b64decode(response_api.image)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
return comfy_io.NodeOutput(returned_image)
|
return IO.NodeOutput(returned_image)
|
||||||
|
|
||||||
|
|
||||||
class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Generates images synchronously based on prompt and resolution.
|
Generates images synchronously based on prompt and resolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityStableImageSD_3_5Node",
|
node_id="StabilityStableImageSD_3_5Node",
|
||||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||||
category="api node/image/Stability AI",
|
category="api node/image/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[x.value for x in Stability_SD3_5_Model],
|
options=Stability_SD3_5_Model,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[x.value for x in StabilityAspectRatio],
|
options=StabilityAspectRatio,
|
||||||
default=StabilityAspectRatio.ratio_1_1.value,
|
default=StabilityAspectRatio.ratio_1_1,
|
||||||
tooltip="Aspect ratio of generated image.",
|
tooltip="Aspect ratio of generated image.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"style_preset",
|
"style_preset",
|
||||||
options=get_stability_style_presets(),
|
options=get_stability_style_presets(),
|
||||||
tooltip="Optional desired style of generated image.",
|
tooltip="Optional desired style of generated image.",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"cfg_scale",
|
"cfg_scale",
|
||||||
default=4.0,
|
default=4.0,
|
||||||
min=1.0,
|
min=1.0,
|
||||||
@ -238,28 +224,28 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
|||||||
step=0.1,
|
step=0.1,
|
||||||
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
tooltip="How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for creating the noise.",
|
tooltip="The random seed used for creating the noise.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||||
force_input=True,
|
force_input=True,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"image_denoise",
|
"image_denoise",
|
||||||
default=0.5,
|
default=0.5,
|
||||||
min=0.0,
|
min=0.0,
|
||||||
@ -270,12 +256,12 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -292,7 +278,7 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
|||||||
image: Optional[torch.Tensor] = None,
|
image: Optional[torch.Tensor] = None,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
image_denoise: Optional[float] = 0.5,
|
image_denoise: Optional[float] = 0.5,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
# prepare image binary if image present
|
# prepare image binary if image present
|
||||||
image_binary = None
|
image_binary = None
|
||||||
@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
data=StabilityStable3_5Request(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/generate/sd3",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityStable3_5Request,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=StabilityStable3_5Request(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
|
||||||
@ -348,30 +324,30 @@ class StabilityStableImageSD_3_5Node(comfy_io.ComfyNode):
|
|||||||
image_data = base64.b64decode(response_api.image)
|
image_data = base64.b64decode(response_api.image)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
return comfy_io.NodeOutput(returned_image)
|
return IO.NodeOutput(returned_image)
|
||||||
|
|
||||||
|
|
||||||
class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Upscale image with minimal alterations to 4K resolution.
|
Upscale image with minimal alterations to 4K resolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityUpscaleConservativeNode",
|
node_id="StabilityUpscaleConservativeNode",
|
||||||
display_name="Stability AI Upscale Conservative",
|
display_name="Stability AI Upscale Conservative",
|
||||||
category="api node/image/Stability AI",
|
category="api node/image/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"creativity",
|
"creativity",
|
||||||
default=0.35,
|
default=0.35,
|
||||||
min=0.2,
|
min=0.2,
|
||||||
@ -379,17 +355,17 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
|||||||
step=0.01,
|
step=0.01,
|
||||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for creating the noise.",
|
tooltip="The random seed used for creating the noise.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||||
@ -398,12 +374,12 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -416,7 +392,7 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
|||||||
creativity: float,
|
creativity: float,
|
||||||
seed: int,
|
seed: int,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||||
|
|
||||||
@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
data=StabilityUpscaleConservativeRequest(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityUpscaleConservativeRequest,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=StabilityUpscaleConservativeRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
creativity=round(creativity,2),
|
creativity=round(creativity,2),
|
||||||
@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
|
||||||
@ -457,30 +423,30 @@ class StabilityUpscaleConservativeNode(comfy_io.ComfyNode):
|
|||||||
image_data = base64.b64decode(response_api.image)
|
image_data = base64.b64decode(response_api.image)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
return comfy_io.NodeOutput(returned_image)
|
return IO.NodeOutput(returned_image)
|
||||||
|
|
||||||
|
|
||||||
class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Upscale image with minimal alterations to 4K resolution.
|
Upscale image with minimal alterations to 4K resolution.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityUpscaleCreativeNode",
|
node_id="StabilityUpscaleCreativeNode",
|
||||||
display_name="Stability AI Upscale Creative",
|
display_name="Stability AI Upscale Creative",
|
||||||
category="api node/image/Stability AI",
|
category="api node/image/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
tooltip="What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results.",
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"creativity",
|
"creativity",
|
||||||
default=0.3,
|
default=0.3,
|
||||||
min=0.1,
|
min=0.1,
|
||||||
@ -488,22 +454,22 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
|||||||
step=0.01,
|
step=0.01,
|
||||||
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
tooltip="Controls the likelihood of creating additional details not heavily conditioned by the init image.",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"style_preset",
|
"style_preset",
|
||||||
options=get_stability_style_presets(),
|
options=get_stability_style_presets(),
|
||||||
tooltip="Optional desired style of generated image.",
|
tooltip="Optional desired style of generated image.",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for creating the noise.",
|
tooltip="The random seed used for creating the noise.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
default="",
|
default="",
|
||||||
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
tooltip="Keywords of what you do not wish to see in the output image. This is an advanced feature.",
|
||||||
@ -512,12 +478,12 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -531,7 +497,7 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
|||||||
style_preset: str,
|
style_preset: str,
|
||||||
seed: int,
|
seed: int,
|
||||||
negative_prompt: str = "",
|
negative_prompt: str = "",
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=False)
|
validate_string(prompt, strip_whitespace=False)
|
||||||
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
|
||||||
|
|
||||||
@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
|||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
|
||||||
}
|
response_model=StabilityAsyncResponse,
|
||||||
|
data=StabilityUpscaleCreativeRequest(
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/upscale/creative",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=StabilityUpscaleCreativeRequest,
|
|
||||||
response_model=StabilityAsyncResponse,
|
|
||||||
),
|
|
||||||
request=StabilityUpscaleCreativeRequest(
|
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
creativity=round(creativity,2),
|
creativity=round(creativity,2),
|
||||||
@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
operation = PollingOperation(
|
response_poll = await poll_op(
|
||||||
poll_endpoint=ApiEndpoint(
|
cls,
|
||||||
path=f"/proxy/stability/v2beta/results/{response_api.id}",
|
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
|
||||||
method=HttpMethod.GET,
|
response_model=StabilityResultsGetResponse,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=StabilityResultsGetResponse,
|
|
||||||
),
|
|
||||||
poll_interval=3,
|
poll_interval=3,
|
||||||
completed_statuses=[StabilityPollStatus.finished],
|
|
||||||
failed_statuses=[StabilityPollStatus.failed],
|
|
||||||
status_extractor=lambda x: get_async_dummy_status(x),
|
status_extractor=lambda x: get_async_dummy_status(x),
|
||||||
auth_kwargs=auth,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
)
|
)
|
||||||
response_poll: StabilityResultsGetResponse = await operation.execute()
|
|
||||||
|
|
||||||
if response_poll.finish_reason != "SUCCESS":
|
if response_poll.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
|
||||||
@ -591,61 +539,50 @@ class StabilityUpscaleCreativeNode(comfy_io.ComfyNode):
|
|||||||
image_data = base64.b64decode(response_poll.result)
|
image_data = base64.b64decode(response_poll.result)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
return comfy_io.NodeOutput(returned_image)
|
return IO.NodeOutput(returned_image)
|
||||||
|
|
||||||
|
|
||||||
class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||||
"""
|
"""
|
||||||
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityUpscaleFastNode",
|
node_id="StabilityUpscaleFastNode",
|
||||||
display_name="Stability AI Upscale Fast",
|
display_name="Stability AI Upscale Fast",
|
||||||
category="api node/image/Stability AI",
|
category="api node/image/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def execute(cls, image: torch.Tensor) -> comfy_io.NodeOutput:
|
async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
|
||||||
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
|
||||||
|
|
||||||
files = {
|
files = {
|
||||||
"image": image_binary
|
"image": image_binary
|
||||||
}
|
}
|
||||||
|
|
||||||
auth = {
|
response_api = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
|
||||||
}
|
response_model=StabilityStableUltraResponse,
|
||||||
|
|
||||||
operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path="/proxy/stability/v2beta/stable-image/upscale/fast",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=StabilityStableUltraResponse,
|
|
||||||
),
|
|
||||||
request=EmptyRequest(),
|
|
||||||
files=files,
|
files=files,
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
|
|
||||||
if response_api.finish_reason != "SUCCESS":
|
if response_api.finish_reason != "SUCCESS":
|
||||||
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
|
||||||
@ -653,26 +590,26 @@ class StabilityUpscaleFastNode(comfy_io.ComfyNode):
|
|||||||
image_data = base64.b64decode(response_api.image)
|
image_data = base64.b64decode(response_api.image)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
|
|
||||||
return comfy_io.NodeOutput(returned_image)
|
return IO.NodeOutput(returned_image)
|
||||||
|
|
||||||
|
|
||||||
class StabilityTextToAudio(comfy_io.ComfyNode):
|
class StabilityTextToAudio(IO.ComfyNode):
|
||||||
"""Generates high-quality music and sound effects from text descriptions."""
|
"""Generates high-quality music and sound effects from text descriptions."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityTextToAudio",
|
node_id="StabilityTextToAudio",
|
||||||
display_name="Stability AI Text To Audio",
|
display_name="Stability AI Text To Audio",
|
||||||
category="api node/audio/Stability AI",
|
category="api node/audio/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["stable-audio-2.5"],
|
options=["stable-audio-2.5"],
|
||||||
),
|
),
|
||||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
IO.String.Input("prompt", multiline=True, default=""),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=190,
|
default=190,
|
||||||
min=1,
|
min=1,
|
||||||
@ -681,18 +618,18 @@ class StabilityTextToAudio(comfy_io.ComfyNode):
|
|||||||
tooltip="Controls the duration in seconds of the generated audio.",
|
tooltip="Controls the duration in seconds of the generated audio.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for generation.",
|
tooltip="The random seed used for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
default=8,
|
default=8,
|
||||||
min=4,
|
min=4,
|
||||||
@ -703,58 +640,50 @@ class StabilityTextToAudio(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Audio.Output(),
|
IO.Audio.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> comfy_io.NodeOutput:
|
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
|
||||||
validate_string(prompt, max_length=10000)
|
validate_string(prompt, max_length=10000)
|
||||||
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
|
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=StabilityAudioResponse,
|
||||||
request_model=StabilityTextToAudioRequest,
|
data=payload,
|
||||||
response_model=StabilityAudioResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
auth_kwargs= {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
if not response_api.audio:
|
if not response_api.audio:
|
||||||
raise ValueError("No audio file was received in response.")
|
raise ValueError("No audio file was received in response.")
|
||||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|
||||||
|
|
||||||
class StabilityAudioToAudio(comfy_io.ComfyNode):
|
class StabilityAudioToAudio(IO.ComfyNode):
|
||||||
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
"""Transforms existing audio samples into new high-quality compositions using text instructions."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityAudioToAudio",
|
node_id="StabilityAudioToAudio",
|
||||||
display_name="Stability AI Audio To Audio",
|
display_name="Stability AI Audio To Audio",
|
||||||
category="api node/audio/Stability AI",
|
category="api node/audio/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["stable-audio-2.5"],
|
options=["stable-audio-2.5"],
|
||||||
),
|
),
|
||||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
IO.String.Input("prompt", multiline=True, default=""),
|
||||||
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=190,
|
default=190,
|
||||||
min=1,
|
min=1,
|
||||||
@ -763,18 +692,18 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
|
|||||||
tooltip="Controls the duration in seconds of the generated audio.",
|
tooltip="Controls the duration in seconds of the generated audio.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for generation.",
|
tooltip="The random seed used for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
default=8,
|
default=8,
|
||||||
min=4,
|
min=4,
|
||||||
@ -783,24 +712,24 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
|
|||||||
tooltip="Controls the number of sampling steps.",
|
tooltip="Controls the number of sampling steps.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Float.Input(
|
IO.Float.Input(
|
||||||
"strength",
|
"strength",
|
||||||
default=1,
|
default=1,
|
||||||
min=0.01,
|
min=0.01,
|
||||||
max=1.0,
|
max=1.0,
|
||||||
step=0.01,
|
step=0.01,
|
||||||
display_mode=comfy_io.NumberDisplay.slider,
|
display_mode=IO.NumberDisplay.slider,
|
||||||
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
tooltip="Parameter controls how much influence the audio parameter has on the generated audio.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Audio.Output(),
|
IO.Audio.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -808,51 +737,43 @@ class StabilityAudioToAudio(comfy_io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
cls, model: str, prompt: str, audio: Input.Audio, duration: int, seed: int, steps: int, strength: float
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, max_length=10000)
|
validate_string(prompt, max_length=10000)
|
||||||
validate_audio_duration(audio, 6, 190)
|
validate_audio_duration(audio, 6, 190)
|
||||||
payload = StabilityAudioToAudioRequest(
|
payload = StabilityAudioToAudioRequest(
|
||||||
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
|
||||||
)
|
)
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
|
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=StabilityAudioResponse,
|
||||||
request_model=StabilityAudioToAudioRequest,
|
data=payload,
|
||||||
response_model=StabilityAudioResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
files={"audio": audio_input_to_mp3(audio)},
|
files={"audio": audio_input_to_mp3(audio)},
|
||||||
auth_kwargs= {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
if not response_api.audio:
|
if not response_api.audio:
|
||||||
raise ValueError("No audio file was received in response.")
|
raise ValueError("No audio file was received in response.")
|
||||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|
||||||
|
|
||||||
class StabilityAudioInpaint(comfy_io.ComfyNode):
|
class StabilityAudioInpaint(IO.ComfyNode):
|
||||||
"""Transforms part of existing audio sample using text instructions."""
|
"""Transforms part of existing audio sample using text instructions."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="StabilityAudioInpaint",
|
node_id="StabilityAudioInpaint",
|
||||||
display_name="Stability AI Audio Inpaint",
|
display_name="Stability AI Audio Inpaint",
|
||||||
category="api node/audio/Stability AI",
|
category="api node/audio/Stability AI",
|
||||||
description=cleandoc(cls.__doc__ or ""),
|
description=cleandoc(cls.__doc__ or ""),
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["stable-audio-2.5"],
|
options=["stable-audio-2.5"],
|
||||||
),
|
),
|
||||||
comfy_io.String.Input("prompt", multiline=True, default=""),
|
IO.String.Input("prompt", multiline=True, default=""),
|
||||||
comfy_io.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
IO.Audio.Input("audio", tooltip="Audio must be between 6 and 190 seconds long."),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=190,
|
default=190,
|
||||||
min=1,
|
min=1,
|
||||||
@ -861,18 +782,18 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
|||||||
tooltip="Controls the duration in seconds of the generated audio.",
|
tooltip="Controls the duration in seconds of the generated audio.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=4294967294,
|
max=4294967294,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="The random seed used for generation.",
|
tooltip="The random seed used for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
default=8,
|
default=8,
|
||||||
min=4,
|
min=4,
|
||||||
@ -881,7 +802,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
|||||||
tooltip="Controls the number of sampling steps.",
|
tooltip="Controls the number of sampling steps.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"mask_start",
|
"mask_start",
|
||||||
default=30,
|
default=30,
|
||||||
min=0,
|
min=0,
|
||||||
@ -889,7 +810,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
|||||||
step=1,
|
step=1,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"mask_end",
|
"mask_end",
|
||||||
default=190,
|
default=190,
|
||||||
min=0,
|
min=0,
|
||||||
@ -899,12 +820,12 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Audio.Output(),
|
IO.Audio.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -920,7 +841,7 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
|||||||
steps: int,
|
steps: int,
|
||||||
mask_start: int,
|
mask_start: int,
|
||||||
mask_end: int,
|
mask_end: int,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, max_length=10000)
|
validate_string(prompt, max_length=10000)
|
||||||
if mask_end <= mask_start:
|
if mask_end <= mask_start:
|
||||||
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
raise ValueError(f"Value of mask_end({mask_end}) should be greater then mask_start({mask_start})")
|
||||||
@ -935,30 +856,22 @@ class StabilityAudioInpaint(comfy_io.ComfyNode):
|
|||||||
mask_start=mask_start,
|
mask_start=mask_start,
|
||||||
mask_end=mask_end,
|
mask_end=mask_end,
|
||||||
)
|
)
|
||||||
operation = SynchronousOperation(
|
response_api = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
|
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=StabilityAudioResponse,
|
||||||
request_model=StabilityAudioInpaintRequest,
|
data=payload,
|
||||||
response_model=StabilityAudioResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
content_type="multipart/form-data",
|
content_type="multipart/form-data",
|
||||||
files={"audio": audio_input_to_mp3(audio)},
|
files={"audio": audio_input_to_mp3(audio)},
|
||||||
auth_kwargs={
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
response_api = await operation.execute()
|
|
||||||
if not response_api.audio:
|
if not response_api.audio:
|
||||||
raise ValueError("No audio file was received in response.")
|
raise ValueError("No audio file was received in response.")
|
||||||
return comfy_io.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
|
||||||
|
|
||||||
|
|
||||||
class StabilityExtension(ComfyExtension):
|
class StabilityExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
StabilityStableImageUltraNode,
|
StabilityStableImageUltraNode,
|
||||||
StabilityStableImageSD_3_5Node,
|
StabilityStableImageSD_3_5Node,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -1,57 +1,35 @@
|
|||||||
import logging
|
|
||||||
import base64
|
import base64
|
||||||
import aiohttp
|
|
||||||
import torch
|
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional
|
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
|
||||||
from comfy_api.input_impl.video_types import VideoFromFile
|
from comfy_api.input_impl.video_types import VideoFromFile
|
||||||
from comfy_api_nodes.apis import (
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
VeoGenVidRequest,
|
from comfy_api_nodes.apis.veo_api import (
|
||||||
VeoGenVidResponse,
|
|
||||||
VeoGenVidPollRequest,
|
VeoGenVidPollRequest,
|
||||||
VeoGenVidPollResponse,
|
VeoGenVidPollResponse,
|
||||||
|
VeoGenVidRequest,
|
||||||
|
VeoGenVidResponse,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apis.client import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_video_output,
|
||||||
SynchronousOperation,
|
poll_op,
|
||||||
PollingOperation,
|
sync_op,
|
||||||
)
|
|
||||||
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
|
||||||
downscale_image_tensor,
|
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
)
|
)
|
||||||
|
|
||||||
AVERAGE_DURATION_VIDEO_GEN = 32
|
AVERAGE_DURATION_VIDEO_GEN = 32
|
||||||
|
MODELS_MAP = {
|
||||||
def convert_image_to_base64(image: torch.Tensor):
|
"veo-2.0-generate-001": "veo-2.0-generate-001",
|
||||||
if image is None:
|
"veo-3.1-generate": "veo-3.1-generate-preview",
|
||||||
return None
|
"veo-3.1-fast-generate": "veo-3.1-fast-generate-preview",
|
||||||
|
"veo-3.0-generate-001": "veo-3.0-generate-001",
|
||||||
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
|
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
|
||||||
return tensor_to_base64_string(scaled_image)
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
|
class VeoVideoGenerationNode(IO.ComfyNode):
|
||||||
if (
|
|
||||||
poll_response.response
|
|
||||||
and hasattr(poll_response.response, "videos")
|
|
||||||
and poll_response.response.videos
|
|
||||||
and len(poll_response.response.videos) > 0
|
|
||||||
):
|
|
||||||
video = poll_response.response.videos[0]
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
|
||||||
return str(video.gcsUri)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
|
||||||
"""
|
"""
|
||||||
Generates videos from text prompts using Google's Veo API.
|
Generates videos from text prompts using Google's Veo API.
|
||||||
|
|
||||||
@ -61,71 +39,71 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="VeoVideoGenerationNode",
|
node_id="VeoVideoGenerationNode",
|
||||||
display_name="Google Veo 2 Video Generation",
|
display_name="Google Veo 2 Video Generation",
|
||||||
category="api node/video/Veo",
|
category="api node/video/Veo",
|
||||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text description of the video",
|
tooltip="Text description of the video",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=["16:9", "9:16"],
|
options=["16:9", "9:16"],
|
||||||
default="16:9",
|
default="16:9",
|
||||||
tooltip="Aspect ratio of the output video",
|
tooltip="Aspect ratio of the output video",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration_seconds",
|
"duration_seconds",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=8,
|
max=8,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration of the output video in seconds",
|
tooltip="Duration of the output video in seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"enhance_prompt",
|
"enhance_prompt",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance",
|
tooltip="Whether to enhance the prompt with AI assistance",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"person_generation",
|
"person_generation",
|
||||||
options=["ALLOW", "BLOCK"],
|
options=["ALLOW", "BLOCK"],
|
||||||
default="ALLOW",
|
default="ALLOW",
|
||||||
tooltip="Whether to allow generating people in the video",
|
tooltip="Whether to allow generating people in the video",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=0xFFFFFFFF,
|
max=0xFFFFFFFF,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation (0 for random)",
|
tooltip="Seed for video generation (0 for random)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Optional reference image to guide video generation",
|
tooltip="Optional reference image to guide video generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["veo-2.0-generate-001"],
|
options=["veo-2.0-generate-001"],
|
||||||
default="veo-2.0-generate-001",
|
default="veo-2.0-generate-001",
|
||||||
@ -134,12 +112,12 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -158,21 +136,17 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
model="veo-2.0-generate-001",
|
model="veo-2.0-generate-001",
|
||||||
generate_audio=False,
|
generate_audio=False,
|
||||||
):
|
):
|
||||||
|
model = MODELS_MAP[model]
|
||||||
# Prepare the instances for the request
|
# Prepare the instances for the request
|
||||||
instances = []
|
instances = []
|
||||||
|
|
||||||
instance = {
|
instance = {"prompt": prompt}
|
||||||
"prompt": prompt
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add image if provided
|
# Add image if provided
|
||||||
if image is not None:
|
if image is not None:
|
||||||
image_base64 = convert_image_to_base64(image)
|
image_base64 = tensor_to_base64_string(image)
|
||||||
if image_base64:
|
if image_base64:
|
||||||
instance["image"] = {
|
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
|
||||||
"bytesBase64Encoded": image_base64,
|
|
||||||
"mimeType": "image/png"
|
|
||||||
}
|
|
||||||
|
|
||||||
instances.append(instance)
|
instances.append(instance)
|
||||||
|
|
||||||
@ -190,119 +164,77 @@ class VeoVideoGenerationNode(comfy_io.ComfyNode):
|
|||||||
if seed > 0:
|
if seed > 0:
|
||||||
parameters["seed"] = seed
|
parameters["seed"] = seed
|
||||||
# Only add generateAudio for Veo 3 models
|
# Only add generateAudio for Veo 3 models
|
||||||
if "veo-3.0" in model:
|
if model.find("veo-2.0") == -1:
|
||||||
parameters["generateAudio"] = generate_audio
|
parameters["generateAudio"] = generate_audio
|
||||||
|
|
||||||
auth = {
|
initial_response = await sync_op(
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
cls,
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||||
}
|
response_model=VeoGenVidResponse,
|
||||||
# Initial request to start video generation
|
data=VeoGenVidRequest(
|
||||||
initial_operation = SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=f"/proxy/veo/{model}/generate",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=VeoGenVidRequest,
|
|
||||||
response_model=VeoGenVidResponse
|
|
||||||
),
|
|
||||||
request=VeoGenVidRequest(
|
|
||||||
instances=instances,
|
instances=instances,
|
||||||
parameters=parameters
|
parameters=parameters,
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_response = await initial_operation.execute()
|
|
||||||
operation_name = initial_response.name
|
|
||||||
|
|
||||||
logging.info(f"Veo generation started with operation name: {operation_name}")
|
|
||||||
|
|
||||||
# Define status extractor function
|
|
||||||
def status_extractor(response):
|
def status_extractor(response):
|
||||||
# Only return "completed" if the operation is done, regardless of success or failure
|
# Only return "completed" if the operation is done, regardless of success or failure
|
||||||
# We'll check for errors after polling completes
|
# We'll check for errors after polling completes
|
||||||
return "completed" if response.done else "pending"
|
return "completed" if response.done else "pending"
|
||||||
|
|
||||||
# Define progress extractor function
|
poll_response = await poll_op(
|
||||||
def progress_extractor(response):
|
cls,
|
||||||
# Could be enhanced if the API provides progress information
|
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||||
return None
|
response_model=VeoGenVidPollResponse,
|
||||||
|
|
||||||
# Define the polling operation
|
|
||||||
poll_operation = PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path=f"/proxy/veo/{model}/poll",
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=VeoGenVidPollRequest,
|
|
||||||
response_model=VeoGenVidPollResponse
|
|
||||||
),
|
|
||||||
completed_statuses=["completed"],
|
|
||||||
failed_statuses=[], # No failed statuses, we'll handle errors after polling
|
|
||||||
status_extractor=status_extractor,
|
status_extractor=status_extractor,
|
||||||
progress_extractor=progress_extractor,
|
data=VeoGenVidPollRequest(
|
||||||
request=VeoGenVidPollRequest(
|
operationName=initial_response.name,
|
||||||
operationName=operation_name
|
|
||||||
),
|
),
|
||||||
auth_kwargs=auth,
|
|
||||||
poll_interval=5.0,
|
poll_interval=5.0,
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Execute the polling operation
|
|
||||||
poll_response = await poll_operation.execute()
|
|
||||||
|
|
||||||
# Now check for errors in the final response
|
# Now check for errors in the final response
|
||||||
# Check for error in poll response
|
# Check for error in poll response
|
||||||
if hasattr(poll_response, 'error') and poll_response.error:
|
if poll_response.error:
|
||||||
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
|
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||||
logging.error(error_message)
|
|
||||||
raise Exception(error_message)
|
|
||||||
|
|
||||||
# Check for RAI filtered content
|
# Check for RAI filtered content
|
||||||
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
|
if (
|
||||||
poll_response.response.raiMediaFilteredCount > 0):
|
hasattr(poll_response.response, "raiMediaFilteredCount")
|
||||||
|
and poll_response.response.raiMediaFilteredCount > 0
|
||||||
|
):
|
||||||
|
|
||||||
# Extract reason message if available
|
# Extract reason message if available
|
||||||
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
|
if (
|
||||||
poll_response.response.raiMediaFilteredReasons):
|
hasattr(poll_response.response, "raiMediaFilteredReasons")
|
||||||
|
and poll_response.response.raiMediaFilteredReasons
|
||||||
|
):
|
||||||
reason = poll_response.response.raiMediaFilteredReasons[0]
|
reason = poll_response.response.raiMediaFilteredReasons[0]
|
||||||
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||||
else:
|
else:
|
||||||
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
|
||||||
|
|
||||||
logging.error(error_message)
|
|
||||||
raise Exception(error_message)
|
raise Exception(error_message)
|
||||||
|
|
||||||
# Extract video data
|
# Extract video data
|
||||||
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
|
if (
|
||||||
|
poll_response.response
|
||||||
|
and hasattr(poll_response.response, "videos")
|
||||||
|
and poll_response.response.videos
|
||||||
|
and len(poll_response.response.videos) > 0
|
||||||
|
):
|
||||||
video = poll_response.response.videos[0]
|
video = poll_response.response.videos[0]
|
||||||
|
|
||||||
# Check if video is provided as base64 or URL
|
# Check if video is provided as base64 or URL
|
||||||
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
|
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||||
# Decode base64 string to bytes
|
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||||
video_data = base64.b64decode(video.bytesBase64Encoded)
|
|
||||||
elif hasattr(video, 'gcsUri') and video.gcsUri:
|
|
||||||
# Download from URL
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(video.gcsUri) as video_response:
|
|
||||||
video_data = await video_response.content.read()
|
|
||||||
else:
|
|
||||||
raise Exception("Video returned but no data or URL was provided")
|
|
||||||
else:
|
|
||||||
raise Exception("Video generation completed but no video was returned")
|
|
||||||
|
|
||||||
if not video_data:
|
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||||
raise Exception("No video data was returned")
|
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||||
|
|
||||||
logging.info("Video generation completed successfully")
|
raise Exception("Video returned but no data or URL was provided")
|
||||||
|
raise Exception("Video generation completed but no video was returned")
|
||||||
# Convert video data to BytesIO object
|
|
||||||
video_io = BytesIO(video_data)
|
|
||||||
|
|
||||||
# Return VideoFromFile object
|
|
||||||
return comfy_io.NodeOutput(VideoFromFile(video_io))
|
|
||||||
|
|
||||||
|
|
||||||
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||||
@ -319,78 +251,83 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="Veo3VideoGenerationNode",
|
node_id="Veo3VideoGenerationNode",
|
||||||
display_name="Google Veo 3 Video Generation",
|
display_name="Google Veo 3 Video Generation",
|
||||||
category="api node/video/Veo",
|
category="api node/video/Veo",
|
||||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Text description of the video",
|
tooltip="Text description of the video",
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=["16:9", "9:16"],
|
options=["16:9", "9:16"],
|
||||||
default="16:9",
|
default="16:9",
|
||||||
tooltip="Aspect ratio of the output video",
|
tooltip="Aspect ratio of the output video",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration_seconds",
|
"duration_seconds",
|
||||||
default=8,
|
default=8,
|
||||||
min=8,
|
min=8,
|
||||||
max=8,
|
max=8,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
tooltip="Duration of the output video in seconds (Veo 3 only supports 8 seconds)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"enhance_prompt",
|
"enhance_prompt",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance",
|
tooltip="Whether to enhance the prompt with AI assistance",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"person_generation",
|
"person_generation",
|
||||||
options=["ALLOW", "BLOCK"],
|
options=["ALLOW", "BLOCK"],
|
||||||
default="ALLOW",
|
default="ALLOW",
|
||||||
tooltip="Whether to allow generating people in the video",
|
tooltip="Whether to allow generating people in the video",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=0xFFFFFFFF,
|
max=0xFFFFFFFF,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation (0 for random)",
|
tooltip="Seed for video generation (0 for random)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Optional reference image to guide video generation",
|
tooltip="Optional reference image to guide video generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["veo-3.0-generate-001", "veo-3.0-fast-generate-001"],
|
options=[
|
||||||
|
"veo-3.1-generate",
|
||||||
|
"veo-3.1-fast-generate",
|
||||||
|
"veo-3.0-generate-001",
|
||||||
|
"veo-3.0-fast-generate-001",
|
||||||
|
],
|
||||||
default="veo-3.0-generate-001",
|
default="veo-3.0-generate-001",
|
||||||
tooltip="Veo 3 model to use for video generation",
|
tooltip="Veo 3 model to use for video generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
tooltip="Generate audio for the video. Supported by all Veo 3 models.",
|
||||||
@ -398,12 +335,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -411,11 +348,12 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
|||||||
|
|
||||||
class VeoExtension(ComfyExtension):
|
class VeoExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
VeoVideoGenerationNode,
|
VeoVideoGenerationNode,
|
||||||
Veo3VideoGenerationNode,
|
Veo3VideoGenerationNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> VeoExtension:
|
async def comfy_entrypoint() -> VeoExtension:
|
||||||
return VeoExtension()
|
return VeoExtension()
|
||||||
|
|||||||
@ -1,27 +1,23 @@
|
|||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Optional, Literal, TypeVar
|
from typing import Literal, Optional, TypeVar
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import ComfyExtension, io as comfy_io
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.util.validation_utils import (
|
from comfy_api_nodes.util import (
|
||||||
validate_aspect_ratio_closeness,
|
|
||||||
validate_image_dimensions,
|
|
||||||
validate_image_aspect_ratio_range,
|
|
||||||
get_number_of_images,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
HttpMethod,
|
download_url_to_video_output,
|
||||||
SynchronousOperation,
|
get_number_of_images,
|
||||||
PollingOperation,
|
poll_op,
|
||||||
EmptyRequest,
|
sync_op,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
validate_image_aspect_ratio,
|
||||||
|
validate_image_dimensions,
|
||||||
|
validate_images_aspect_ratio_closeness,
|
||||||
)
|
)
|
||||||
from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi
|
|
||||||
|
|
||||||
|
|
||||||
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video"
|
||||||
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video"
|
||||||
@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations"
|
|||||||
|
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
class VideoModelName(str, Enum):
|
class VideoModelName(str, Enum):
|
||||||
vidu_q1 = 'viduq1'
|
vidu_q1 = "viduq1"
|
||||||
|
|
||||||
|
|
||||||
class AspectRatio(str, Enum):
|
class AspectRatio(str, Enum):
|
||||||
@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel):
|
|||||||
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL")
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
|
||||||
created = "created"
|
|
||||||
queueing = "queueing"
|
|
||||||
processing = "processing"
|
|
||||||
success = "success"
|
|
||||||
failed = "failed"
|
|
||||||
|
|
||||||
|
|
||||||
class TaskCreationResponse(BaseModel):
|
class TaskCreationResponse(BaseModel):
|
||||||
task_id: str = Field(...)
|
task_id: str = Field(...)
|
||||||
state: TaskStatus = Field(...)
|
state: str = Field(...)
|
||||||
created_at: str = Field(...)
|
created_at: str = Field(...)
|
||||||
code: Optional[int] = Field(None, description="Error code")
|
code: Optional[int] = Field(None, description="Error code")
|
||||||
|
|
||||||
@ -85,32 +74,11 @@ class TaskResult(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class TaskStatusResponse(BaseModel):
|
class TaskStatusResponse(BaseModel):
|
||||||
state: TaskStatus = Field(...)
|
state: str = Field(...)
|
||||||
err_code: Optional[str] = Field(None)
|
err_code: Optional[str] = Field(None)
|
||||||
creations: list[TaskResult] = Field(..., description="Generated results")
|
creations: list[TaskResult] = Field(..., description="Generated results")
|
||||||
|
|
||||||
|
|
||||||
async def poll_until_finished(
|
|
||||||
auth_kwargs: dict[str, str],
|
|
||||||
api_endpoint: ApiEndpoint[Any, R],
|
|
||||||
result_url_extractor: Optional[Callable[[R], str]] = None,
|
|
||||||
estimated_duration: Optional[int] = None,
|
|
||||||
node_id: Optional[str] = None,
|
|
||||||
) -> R:
|
|
||||||
return await PollingOperation(
|
|
||||||
poll_endpoint=api_endpoint,
|
|
||||||
completed_statuses=[TaskStatus.success.value],
|
|
||||||
failed_statuses=[TaskStatus.failed.value],
|
|
||||||
status_extractor=lambda response: response.state.value,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
result_url_extractor=result_url_extractor,
|
|
||||||
estimated_duration=estimated_duration,
|
|
||||||
node_id=node_id,
|
|
||||||
poll_interval=16.0,
|
|
||||||
max_poll_attempts=256,
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
|
|
||||||
def get_video_url_from_response(response) -> Optional[str]:
|
def get_video_url_from_response(response) -> Optional[str]:
|
||||||
if response.creations:
|
if response.creations:
|
||||||
return response.creations[0].url
|
return response.creations[0].url
|
||||||
@ -127,111 +95,101 @@ def get_video_from_response(response) -> TaskResult:
|
|||||||
|
|
||||||
|
|
||||||
async def execute_task(
|
async def execute_task(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
vidu_endpoint: str,
|
vidu_endpoint: str,
|
||||||
auth_kwargs: Optional[dict[str, str]],
|
|
||||||
payload: TaskCreationRequest,
|
payload: TaskCreationRequest,
|
||||||
estimated_duration: int,
|
estimated_duration: int,
|
||||||
node_id: str,
|
|
||||||
) -> R:
|
) -> R:
|
||||||
response = await SynchronousOperation(
|
response = await sync_op(
|
||||||
endpoint=ApiEndpoint(
|
cls,
|
||||||
path=vidu_endpoint,
|
endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"),
|
||||||
method=HttpMethod.POST,
|
response_model=TaskCreationResponse,
|
||||||
request_model=TaskCreationRequest,
|
data=payload,
|
||||||
response_model=TaskCreationResponse,
|
)
|
||||||
),
|
if response.state == "failed":
|
||||||
request=payload,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
).execute()
|
|
||||||
if response.state == TaskStatus.failed:
|
|
||||||
error_msg = f"Vidu request failed. Code: {response.code}"
|
error_msg = f"Vidu request failed. Code: {response.code}"
|
||||||
logging.error(error_msg)
|
logging.error(error_msg)
|
||||||
raise RuntimeError(error_msg)
|
raise RuntimeError(error_msg)
|
||||||
return await poll_until_finished(
|
return await poll_op(
|
||||||
auth_kwargs,
|
cls,
|
||||||
ApiEndpoint(
|
ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id),
|
||||||
path=VIDU_GET_GENERATION_STATUS % response.task_id,
|
response_model=TaskStatusResponse,
|
||||||
method=HttpMethod.GET,
|
status_extractor=lambda r: r.state,
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=TaskStatusResponse,
|
|
||||||
),
|
|
||||||
result_url_extractor=get_video_url_from_response,
|
|
||||||
estimated_duration=estimated_duration,
|
estimated_duration=estimated_duration,
|
||||||
node_id=node_id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ViduTextToVideoNode(comfy_io.ComfyNode):
|
class ViduTextToVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="ViduTextToVideoNode",
|
node_id="ViduTextToVideoNode",
|
||||||
display_name="Vidu Text To Video Generation",
|
display_name="Vidu Text To Video Generation",
|
||||||
category="api node/video/Vidu",
|
category="api node/video/Vidu",
|
||||||
description="Generate video from text prompt",
|
description="Generate video from text prompt",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in VideoModelName],
|
options=VideoModelName,
|
||||||
default=VideoModelName.vidu_q1.value,
|
default=VideoModelName.vidu_q1,
|
||||||
tooltip="Model name",
|
tooltip="Model name",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="A textual description for video generation",
|
tooltip="A textual description for video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=5,
|
max=5,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration of the output video in seconds",
|
tooltip="Duration of the output video in seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation (0 for random)",
|
tooltip="Seed for video generation (0 for random)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[model.value for model in AspectRatio],
|
options=AspectRatio,
|
||||||
default=AspectRatio.r_16_9.value,
|
default=AspectRatio.r_16_9,
|
||||||
tooltip="The aspect ratio of the output video",
|
tooltip="The aspect ratio of the output video",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[model.value for model in Resolution],
|
options=Resolution,
|
||||||
default=Resolution.r_1080p.value,
|
default=Resolution.r_1080p,
|
||||||
tooltip="Supported values may vary by model & duration",
|
tooltip="Supported values may vary by model & duration",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"movement_amplitude",
|
"movement_amplitude",
|
||||||
options=[model.value for model in MovementAmplitude],
|
options=MovementAmplitude,
|
||||||
default=MovementAmplitude.auto.value,
|
default=MovementAmplitude.auto,
|
||||||
tooltip="The movement amplitude of objects in the frame",
|
tooltip="The movement amplitude of objects in the frame",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -246,7 +204,7 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
movement_amplitude: str,
|
movement_amplitude: str,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise ValueError("The prompt field is required and cannot be empty.")
|
raise ValueError("The prompt field is required and cannot be empty.")
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
@ -258,84 +216,80 @@ class ViduTextToVideoNode(comfy_io.ComfyNode):
|
|||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
movement_amplitude=movement_amplitude,
|
movement_amplitude=movement_amplitude,
|
||||||
)
|
)
|
||||||
auth = {
|
results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320)
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id)
|
|
||||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
|
||||||
|
|
||||||
|
|
||||||
class ViduImageToVideoNode(comfy_io.ComfyNode):
|
class ViduImageToVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="ViduImageToVideoNode",
|
node_id="ViduImageToVideoNode",
|
||||||
display_name="Vidu Image To Video Generation",
|
display_name="Vidu Image To Video Generation",
|
||||||
category="api node/video/Vidu",
|
category="api node/video/Vidu",
|
||||||
description="Generate video from image and optional prompt",
|
description="Generate video from image and optional prompt",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in VideoModelName],
|
options=VideoModelName,
|
||||||
default=VideoModelName.vidu_q1.value,
|
default=VideoModelName.vidu_q1,
|
||||||
tooltip="Model name",
|
tooltip="Model name",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="An image to be used as the start frame of the generated video",
|
tooltip="An image to be used as the start frame of the generated video",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="A textual description for video generation",
|
tooltip="A textual description for video generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=5,
|
max=5,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration of the output video in seconds",
|
tooltip="Duration of the output video in seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation (0 for random)",
|
tooltip="Seed for video generation (0 for random)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[model.value for model in Resolution],
|
options=Resolution,
|
||||||
default=Resolution.r_1080p.value,
|
default=Resolution.r_1080p,
|
||||||
tooltip="Supported values may vary by model & duration",
|
tooltip="Supported values may vary by model & duration",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"movement_amplitude",
|
"movement_amplitude",
|
||||||
options=[model.value for model in MovementAmplitude],
|
options=MovementAmplitude,
|
||||||
default=MovementAmplitude.auto.value,
|
default=MovementAmplitude.auto.value,
|
||||||
tooltip="The movement amplitude of objects in the frame",
|
tooltip="The movement amplitude of objects in the frame",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -350,10 +304,10 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
movement_amplitude: str,
|
movement_amplitude: str,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if get_number_of_images(image) > 1:
|
if get_number_of_images(image) > 1:
|
||||||
raise ValueError("Only one input image is allowed.")
|
raise ValueError("Only one input image is allowed.")
|
||||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -362,81 +316,77 @@ class ViduImageToVideoNode(comfy_io.ComfyNode):
|
|||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
movement_amplitude=movement_amplitude,
|
movement_amplitude=movement_amplitude,
|
||||||
)
|
)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
payload.images = await upload_images_to_comfyapi(
|
payload.images = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
image,
|
image,
|
||||||
max_images=1,
|
max_images=1,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120)
|
||||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||||
|
|
||||||
|
|
||||||
class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
class ViduReferenceVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="ViduReferenceVideoNode",
|
node_id="ViduReferenceVideoNode",
|
||||||
display_name="Vidu Reference To Video Generation",
|
display_name="Vidu Reference To Video Generation",
|
||||||
category="api node/video/Vidu",
|
category="api node/video/Vidu",
|
||||||
description="Generate video from multiple images and prompt",
|
description="Generate video from multiple images and prompt",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in VideoModelName],
|
options=VideoModelName,
|
||||||
default=VideoModelName.vidu_q1.value,
|
default=VideoModelName.vidu_q1,
|
||||||
tooltip="Model name",
|
tooltip="Model name",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"images",
|
"images",
|
||||||
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
tooltip="Images to use as references to generate a video with consistent subjects (max 7 images).",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="A textual description for video generation",
|
tooltip="A textual description for video generation",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=5,
|
max=5,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration of the output video in seconds",
|
tooltip="Duration of the output video in seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation (0 for random)",
|
tooltip="Seed for video generation (0 for random)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"aspect_ratio",
|
"aspect_ratio",
|
||||||
options=[model.value for model in AspectRatio],
|
options=AspectRatio,
|
||||||
default=AspectRatio.r_16_9.value,
|
default=AspectRatio.r_16_9,
|
||||||
tooltip="The aspect ratio of the output video",
|
tooltip="The aspect ratio of the output video",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[model.value for model in Resolution],
|
options=[model.value for model in Resolution],
|
||||||
default=Resolution.r_1080p.value,
|
default=Resolution.r_1080p.value,
|
||||||
tooltip="Supported values may vary by model & duration",
|
tooltip="Supported values may vary by model & duration",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"movement_amplitude",
|
"movement_amplitude",
|
||||||
options=[model.value for model in MovementAmplitude],
|
options=[model.value for model in MovementAmplitude],
|
||||||
default=MovementAmplitude.auto.value,
|
default=MovementAmplitude.auto.value,
|
||||||
@ -445,12 +395,12 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -466,14 +416,14 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
movement_amplitude: str,
|
movement_amplitude: str,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
if not prompt:
|
if not prompt:
|
||||||
raise ValueError("The prompt field is required and cannot be empty.")
|
raise ValueError("The prompt field is required and cannot be empty.")
|
||||||
a = get_number_of_images(images)
|
a = get_number_of_images(images)
|
||||||
if a > 7:
|
if a > 7:
|
||||||
raise ValueError("Too many images, maximum allowed is 7.")
|
raise ValueError("Too many images, maximum allowed is 7.")
|
||||||
for image in images:
|
for image in images:
|
||||||
validate_image_aspect_ratio_range(image, (1, 4), (4, 1))
|
validate_image_aspect_ratio(image, (1, 4), (4, 1))
|
||||||
validate_image_dimensions(image, min_width=128, min_height=128)
|
validate_image_dimensions(image, min_width=128, min_height=128)
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@ -484,79 +434,75 @@ class ViduReferenceVideoNode(comfy_io.ComfyNode):
|
|||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
movement_amplitude=movement_amplitude,
|
movement_amplitude=movement_amplitude,
|
||||||
)
|
)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
payload.images = await upload_images_to_comfyapi(
|
payload.images = await upload_images_to_comfyapi(
|
||||||
|
cls,
|
||||||
images,
|
images,
|
||||||
max_images=7,
|
max_images=7,
|
||||||
mime_type="image/png",
|
mime_type="image/png",
|
||||||
auth_kwargs=auth,
|
|
||||||
)
|
)
|
||||||
results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id)
|
results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120)
|
||||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||||
|
|
||||||
|
|
||||||
class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="ViduStartEndToVideoNode",
|
node_id="ViduStartEndToVideoNode",
|
||||||
display_name="Vidu Start End To Video Generation",
|
display_name="Vidu Start End To Video Generation",
|
||||||
category="api node/video/Vidu",
|
category="api node/video/Vidu",
|
||||||
description="Generate a video from start and end frames and a prompt",
|
description="Generate a video from start and end frames and a prompt",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[model.value for model in VideoModelName],
|
options=[model.value for model in VideoModelName],
|
||||||
default=VideoModelName.vidu_q1.value,
|
default=VideoModelName.vidu_q1.value,
|
||||||
tooltip="Model name",
|
tooltip="Model name",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"first_frame",
|
"first_frame",
|
||||||
tooltip="Start frame",
|
tooltip="Start frame",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"end_frame",
|
"end_frame",
|
||||||
tooltip="End frame",
|
tooltip="End frame",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="A textual description for video generation",
|
tooltip="A textual description for video generation",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=5,
|
max=5,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Duration of the output video in seconds",
|
tooltip="Duration of the output video in seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed for video generation (0 for random)",
|
tooltip="Seed for video generation (0 for random)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[model.value for model in Resolution],
|
options=[model.value for model in Resolution],
|
||||||
default=Resolution.r_1080p.value,
|
default=Resolution.r_1080p.value,
|
||||||
tooltip="Supported values may vary by model & duration",
|
tooltip="Supported values may vary by model & duration",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"movement_amplitude",
|
"movement_amplitude",
|
||||||
options=[model.value for model in MovementAmplitude],
|
options=[model.value for model in MovementAmplitude],
|
||||||
default=MovementAmplitude.auto.value,
|
default=MovementAmplitude.auto.value,
|
||||||
@ -565,12 +511,12 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -586,8 +532,8 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
|||||||
seed: int,
|
seed: int,
|
||||||
resolution: str,
|
resolution: str,
|
||||||
movement_amplitude: str,
|
movement_amplitude: str,
|
||||||
) -> comfy_io.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||||
payload = TaskCreationRequest(
|
payload = TaskCreationRequest(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@ -596,21 +542,17 @@ class ViduStartEndToVideoNode(comfy_io.ComfyNode):
|
|||||||
resolution=resolution,
|
resolution=resolution,
|
||||||
movement_amplitude=movement_amplitude,
|
movement_amplitude=movement_amplitude,
|
||||||
)
|
)
|
||||||
auth = {
|
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
|
||||||
}
|
|
||||||
payload.images = [
|
payload.images = [
|
||||||
(await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0]
|
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||||
for frame in (first_frame, end_frame)
|
for frame in (first_frame, end_frame)
|
||||||
]
|
]
|
||||||
results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id)
|
results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96)
|
||||||
return comfy_io.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url))
|
||||||
|
|
||||||
|
|
||||||
class ViduExtension(ComfyExtension):
|
class ViduExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
ViduTextToVideoNode,
|
ViduTextToVideoNode,
|
||||||
ViduImageToVideoNode,
|
ViduImageToVideoNode,
|
||||||
@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension):
|
|||||||
ViduStartEndToVideoNode,
|
ViduStartEndToVideoNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> ViduExtension:
|
async def comfy_entrypoint() -> ViduExtension:
|
||||||
return ViduExtension()
|
return ViduExtension()
|
||||||
|
|||||||
@ -1,28 +1,24 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Optional, Type, Union
|
from typing import Optional
|
||||||
from typing_extensions import override
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from comfy_api.latest import ComfyExtension, Input, io as comfy_io
|
from typing_extensions import override
|
||||||
from comfy_api_nodes.apis.client import (
|
|
||||||
ApiEndpoint,
|
|
||||||
HttpMethod,
|
|
||||||
SynchronousOperation,
|
|
||||||
PollingOperation,
|
|
||||||
EmptyRequest,
|
|
||||||
R,
|
|
||||||
T,
|
|
||||||
)
|
|
||||||
from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration
|
|
||||||
|
|
||||||
from comfy_api_nodes.apinode_utils import (
|
from comfy_api.latest import IO, ComfyExtension, Input
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
ApiEndpoint,
|
||||||
|
audio_to_base64_string,
|
||||||
download_url_to_image_tensor,
|
download_url_to_image_tensor,
|
||||||
download_url_to_video_output,
|
download_url_to_video_output,
|
||||||
|
get_number_of_images,
|
||||||
|
poll_op,
|
||||||
|
sync_op,
|
||||||
tensor_to_base64_string,
|
tensor_to_base64_string,
|
||||||
audio_to_base64_string,
|
validate_audio_duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Text2ImageInputField(BaseModel):
|
class Text2ImageInputField(BaseModel):
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
negative_prompt: Optional[str] = Field(None)
|
negative_prompt: Optional[str] = Field(None)
|
||||||
@ -146,84 +142,38 @@ class VideoTaskStatusResponse(BaseModel):
|
|||||||
request_id: str = Field(...)
|
request_id: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)')
|
RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)")
|
||||||
|
|
||||||
|
|
||||||
async def process_task(
|
class WanTextToImageApi(IO.ComfyNode):
|
||||||
auth_kwargs: dict[str, str],
|
|
||||||
url: str,
|
|
||||||
request_model: Type[T],
|
|
||||||
response_model: Type[R],
|
|
||||||
payload: Union[
|
|
||||||
Text2ImageTaskCreationRequest,
|
|
||||||
Image2ImageTaskCreationRequest,
|
|
||||||
Text2VideoTaskCreationRequest,
|
|
||||||
Image2VideoTaskCreationRequest,
|
|
||||||
],
|
|
||||||
node_id: str,
|
|
||||||
estimated_duration: int,
|
|
||||||
poll_interval: int,
|
|
||||||
) -> Type[R]:
|
|
||||||
initial_response = await SynchronousOperation(
|
|
||||||
endpoint=ApiEndpoint(
|
|
||||||
path=url,
|
|
||||||
method=HttpMethod.POST,
|
|
||||||
request_model=request_model,
|
|
||||||
response_model=TaskCreationResponse,
|
|
||||||
),
|
|
||||||
request=payload,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
if not initial_response.output:
|
|
||||||
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
|
||||||
|
|
||||||
return await PollingOperation(
|
|
||||||
poll_endpoint=ApiEndpoint(
|
|
||||||
path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}",
|
|
||||||
method=HttpMethod.GET,
|
|
||||||
request_model=EmptyRequest,
|
|
||||||
response_model=response_model,
|
|
||||||
),
|
|
||||||
completed_statuses=["SUCCEEDED"],
|
|
||||||
failed_statuses=["FAILED", "CANCELED", "UNKNOWN"],
|
|
||||||
status_extractor=lambda x: x.output.task_status,
|
|
||||||
estimated_duration=estimated_duration,
|
|
||||||
poll_interval=poll_interval,
|
|
||||||
node_id=node_id,
|
|
||||||
auth_kwargs=auth_kwargs,
|
|
||||||
).execute()
|
|
||||||
|
|
||||||
|
|
||||||
class WanTextToImageApi(comfy_io.ComfyNode):
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="WanTextToImageApi",
|
node_id="WanTextToImageApi",
|
||||||
display_name="Wan Text to Image",
|
display_name="Wan Text to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates image based on text prompt.",
|
description="Generates image based on text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-t2i-preview"],
|
options=["wan2.5-t2i-preview"],
|
||||||
default="wan2.5-t2i-preview",
|
default="wan2.5-t2i-preview",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"width",
|
"width",
|
||||||
default=1024,
|
default=1024,
|
||||||
min=768,
|
min=768,
|
||||||
@ -231,7 +181,7 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
|||||||
step=32,
|
step=32,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"height",
|
"height",
|
||||||
default=1024,
|
default=1024,
|
||||||
min=768,
|
min=768,
|
||||||
@ -239,37 +189,37 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
|||||||
step=32,
|
step=32,
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to use for generation.",
|
tooltip="Seed to use for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -286,59 +236,61 @@ class WanTextToImageApi(comfy_io.ComfyNode):
|
|||||||
prompt_extend: bool = True,
|
prompt_extend: bool = True,
|
||||||
watermark: bool = True,
|
watermark: bool = True,
|
||||||
):
|
):
|
||||||
payload = Text2ImageTaskCreationRequest(
|
initial_response = await sync_op(
|
||||||
model=model,
|
cls,
|
||||||
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
|
||||||
parameters=Txt2ImageParametersField(
|
response_model=TaskCreationResponse,
|
||||||
size=f"{width}*{height}",
|
data=Text2ImageTaskCreationRequest(
|
||||||
seed=seed,
|
model=model,
|
||||||
prompt_extend=prompt_extend,
|
input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt),
|
||||||
watermark=watermark,
|
parameters=Txt2ImageParametersField(
|
||||||
|
size=f"{width}*{height}",
|
||||||
|
seed=seed,
|
||||||
|
prompt_extend=prompt_extend,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
response = await process_task(
|
if not initial_response.output:
|
||||||
{
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
response = await poll_op(
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
cls,
|
||||||
},
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
"/proxy/wan/api/v1/services/aigc/text2image/image-synthesis",
|
|
||||||
request_model=Text2ImageTaskCreationRequest,
|
|
||||||
response_model=ImageTaskStatusResponse,
|
response_model=ImageTaskStatusResponse,
|
||||||
payload=payload,
|
status_extractor=lambda x: x.output.task_status,
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=9,
|
estimated_duration=9,
|
||||||
poll_interval=3,
|
poll_interval=3,
|
||||||
)
|
)
|
||||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||||
|
|
||||||
|
|
||||||
class WanImageToImageApi(comfy_io.ComfyNode):
|
class WanImageToImageApi(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="WanImageToImageApi",
|
node_id="WanImageToImageApi",
|
||||||
display_name="Wan Image to Image",
|
display_name="Wan Image to Image",
|
||||||
category="api node/image/Wan",
|
category="api node/image/Wan",
|
||||||
description="Generates an image from one or two input images and a text prompt. "
|
description="Generates an image from one or two input images and a text prompt. "
|
||||||
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
"The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-i2i-preview"],
|
options=["wan2.5-i2i-preview"],
|
||||||
default="wan2.5-i2i-preview",
|
default="wan2.5-i2i-preview",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
tooltip="Single-image editing or multi-image fusion, maximum 2 images.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
@ -346,7 +298,7 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
|||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
# redo this later as an optional combo of recommended resolutions
|
# redo this later as an optional combo of recommended resolutions
|
||||||
# comfy_io.Int.Input(
|
# IO.Int.Input(
|
||||||
# "width",
|
# "width",
|
||||||
# default=1280,
|
# default=1280,
|
||||||
# min=384,
|
# min=384,
|
||||||
@ -354,7 +306,7 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
|||||||
# step=16,
|
# step=16,
|
||||||
# optional=True,
|
# optional=True,
|
||||||
# ),
|
# ),
|
||||||
# comfy_io.Int.Input(
|
# IO.Int.Input(
|
||||||
# "height",
|
# "height",
|
||||||
# default=1280,
|
# default=1280,
|
||||||
# min=384,
|
# min=384,
|
||||||
@ -362,31 +314,31 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
|||||||
# step=16,
|
# step=16,
|
||||||
# optional=True,
|
# optional=True,
|
||||||
# ),
|
# ),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to use for generation.",
|
tooltip="Seed to use for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Image.Output(),
|
IO.Image.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -408,61 +360,63 @@ class WanImageToImageApi(comfy_io.ComfyNode):
|
|||||||
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
raise ValueError(f"Expected 1 or 2 input images, got {n_images}.")
|
||||||
images = []
|
images = []
|
||||||
for i in image:
|
for i in image:
|
||||||
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096))
|
images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096))
|
||||||
payload = Image2ImageTaskCreationRequest(
|
initial_response = await sync_op(
|
||||||
model=model,
|
cls,
|
||||||
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"),
|
||||||
parameters=Image2ImageParametersField(
|
response_model=TaskCreationResponse,
|
||||||
# size=f"{width}*{height}",
|
data=Image2ImageTaskCreationRequest(
|
||||||
seed=seed,
|
model=model,
|
||||||
watermark=watermark,
|
input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images),
|
||||||
|
parameters=Image2ImageParametersField(
|
||||||
|
# size=f"{width}*{height}",
|
||||||
|
seed=seed,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
response = await process_task(
|
if not initial_response.output:
|
||||||
{
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
response = await poll_op(
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
cls,
|
||||||
},
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
"/proxy/wan/api/v1/services/aigc/image2image/image-synthesis",
|
|
||||||
request_model=Image2ImageTaskCreationRequest,
|
|
||||||
response_model=ImageTaskStatusResponse,
|
response_model=ImageTaskStatusResponse,
|
||||||
payload=payload,
|
status_extractor=lambda x: x.output.task_status,
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=42,
|
estimated_duration=42,
|
||||||
poll_interval=3,
|
poll_interval=4,
|
||||||
)
|
)
|
||||||
return comfy_io.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url)))
|
||||||
|
|
||||||
|
|
||||||
class WanTextToVideoApi(comfy_io.ComfyNode):
|
class WanTextToVideoApi(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="WanTextToVideoApi",
|
node_id="WanTextToVideoApi",
|
||||||
display_name="Wan Text to Video",
|
display_name="Wan Text to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates video based on text prompt.",
|
description="Generates video based on text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-t2v-preview"],
|
options=["wan2.5-t2v-preview"],
|
||||||
default="wan2.5-t2v-preview",
|
default="wan2.5-t2v-preview",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"size",
|
"size",
|
||||||
options=[
|
options=[
|
||||||
"480p: 1:1 (624x624)",
|
"480p: 1:1 (624x624)",
|
||||||
@ -482,58 +436,58 @@ class WanTextToVideoApi(comfy_io.ComfyNode):
|
|||||||
default="480p: 1:1 (624x624)",
|
default="480p: 1:1 (624x624)",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=10,
|
max=10,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Available durations: 5 and 10 seconds",
|
tooltip="Available durations: 5 and 10 seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to use for generation.",
|
tooltip="Seed to use for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If there is no audio input, generate audio automatically.",
|
tooltip="If there is no audio input, generate audio automatically.",
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -557,66 +511,69 @@ class WanTextToVideoApi(comfy_io.ComfyNode):
|
|||||||
if audio is not None:
|
if audio is not None:
|
||||||
validate_audio_duration(audio, 3.0, 29.0)
|
validate_audio_duration(audio, 3.0, 29.0)
|
||||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||||
payload = Text2VideoTaskCreationRequest(
|
|
||||||
model=model,
|
initial_response = await sync_op(
|
||||||
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
cls,
|
||||||
parameters=Text2VideoParametersField(
|
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||||
size=f"{width}*{height}",
|
response_model=TaskCreationResponse,
|
||||||
duration=duration,
|
data=Text2VideoTaskCreationRequest(
|
||||||
seed=seed,
|
model=model,
|
||||||
audio=generate_audio,
|
input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url),
|
||||||
prompt_extend=prompt_extend,
|
parameters=Text2VideoParametersField(
|
||||||
watermark=watermark,
|
size=f"{width}*{height}",
|
||||||
|
duration=duration,
|
||||||
|
seed=seed,
|
||||||
|
audio=generate_audio,
|
||||||
|
prompt_extend=prompt_extend,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
response = await process_task(
|
if not initial_response.output:
|
||||||
{
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
response = await poll_op(
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
cls,
|
||||||
},
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
|
||||||
request_model=Text2VideoTaskCreationRequest,
|
|
||||||
response_model=VideoTaskStatusResponse,
|
response_model=VideoTaskStatusResponse,
|
||||||
payload=payload,
|
status_extractor=lambda x: x.output.task_status,
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=120 * int(duration / 5),
|
estimated_duration=120 * int(duration / 5),
|
||||||
poll_interval=6,
|
poll_interval=6,
|
||||||
)
|
)
|
||||||
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||||
|
|
||||||
|
|
||||||
class WanImageToVideoApi(comfy_io.ComfyNode):
|
class WanImageToVideoApi(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return comfy_io.Schema(
|
return IO.Schema(
|
||||||
node_id="WanImageToVideoApi",
|
node_id="WanImageToVideoApi",
|
||||||
display_name="Wan Image to Video",
|
display_name="Wan Image to Video",
|
||||||
category="api node/video/Wan",
|
category="api node/video/Wan",
|
||||||
description="Generates video based on the first frame and text prompt.",
|
description="Generates video based on the first frame and text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=["wan2.5-i2v-preview"],
|
options=["wan2.5-i2v-preview"],
|
||||||
default="wan2.5-i2v-preview",
|
default="wan2.5-i2v-preview",
|
||||||
tooltip="Model to use.",
|
tooltip="Model to use.",
|
||||||
),
|
),
|
||||||
comfy_io.Image.Input(
|
IO.Image.Input(
|
||||||
"image",
|
"image",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
tooltip="Prompt used to describe the elements and visual features, supports English/Chinese.",
|
||||||
),
|
),
|
||||||
comfy_io.String.Input(
|
IO.String.Input(
|
||||||
"negative_prompt",
|
"negative_prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
default="",
|
default="",
|
||||||
tooltip="Negative text prompt to guide what to avoid.",
|
tooltip="Negative text prompt to guide what to avoid.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Combo.Input(
|
IO.Combo.Input(
|
||||||
"resolution",
|
"resolution",
|
||||||
options=[
|
options=[
|
||||||
"480P",
|
"480P",
|
||||||
@ -626,58 +583,58 @@ class WanImageToVideoApi(comfy_io.ComfyNode):
|
|||||||
default="480P",
|
default="480P",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"duration",
|
"duration",
|
||||||
default=5,
|
default=5,
|
||||||
min=5,
|
min=5,
|
||||||
max=10,
|
max=10,
|
||||||
step=5,
|
step=5,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
tooltip="Available durations: 5 and 10 seconds",
|
tooltip="Available durations: 5 and 10 seconds",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Audio.Input(
|
IO.Audio.Input(
|
||||||
"audio",
|
"audio",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
tooltip="Audio must contain a clear, loud voice, without extraneous noise, background music.",
|
||||||
),
|
),
|
||||||
comfy_io.Int.Input(
|
IO.Int.Input(
|
||||||
"seed",
|
"seed",
|
||||||
default=0,
|
default=0,
|
||||||
min=0,
|
min=0,
|
||||||
max=2147483647,
|
max=2147483647,
|
||||||
step=1,
|
step=1,
|
||||||
display_mode=comfy_io.NumberDisplay.number,
|
display_mode=IO.NumberDisplay.number,
|
||||||
control_after_generate=True,
|
control_after_generate=True,
|
||||||
tooltip="Seed to use for generation.",
|
tooltip="Seed to use for generation.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"generate_audio",
|
"generate_audio",
|
||||||
default=False,
|
default=False,
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="If there is no audio input, generate audio automatically.",
|
tooltip="If there is no audio input, generate audio automatically.",
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"prompt_extend",
|
"prompt_extend",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to enhance the prompt with AI assistance.",
|
tooltip="Whether to enhance the prompt with AI assistance.",
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
comfy_io.Boolean.Input(
|
IO.Boolean.Input(
|
||||||
"watermark",
|
"watermark",
|
||||||
default=True,
|
default=True,
|
||||||
tooltip="Whether to add an \"AI generated\" watermark to the result.",
|
tooltip='Whether to add an "AI generated" watermark to the result.',
|
||||||
optional=True,
|
optional=True,
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
comfy_io.Video.Output(),
|
IO.Video.Output(),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
comfy_io.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
comfy_io.Hidden.api_key_comfy_org,
|
IO.Hidden.api_key_comfy_org,
|
||||||
comfy_io.Hidden.unique_id,
|
IO.Hidden.unique_id,
|
||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
)
|
)
|
||||||
@ -699,44 +656,46 @@ class WanImageToVideoApi(comfy_io.ComfyNode):
|
|||||||
):
|
):
|
||||||
if get_number_of_images(image) != 1:
|
if get_number_of_images(image) != 1:
|
||||||
raise ValueError("Exactly one input image is required.")
|
raise ValueError("Exactly one input image is required.")
|
||||||
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000)
|
image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000)
|
||||||
audio_url = None
|
audio_url = None
|
||||||
if audio is not None:
|
if audio is not None:
|
||||||
validate_audio_duration(audio, 3.0, 29.0)
|
validate_audio_duration(audio, 3.0, 29.0)
|
||||||
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame")
|
||||||
payload = Image2VideoTaskCreationRequest(
|
initial_response = await sync_op(
|
||||||
model=model,
|
cls,
|
||||||
input=Image2VideoInputField(
|
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
|
||||||
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
response_model=TaskCreationResponse,
|
||||||
),
|
data=Image2VideoTaskCreationRequest(
|
||||||
parameters=Image2VideoParametersField(
|
model=model,
|
||||||
resolution=resolution,
|
input=Image2VideoInputField(
|
||||||
duration=duration,
|
prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url
|
||||||
seed=seed,
|
),
|
||||||
audio=generate_audio,
|
parameters=Image2VideoParametersField(
|
||||||
prompt_extend=prompt_extend,
|
resolution=resolution,
|
||||||
watermark=watermark,
|
duration=duration,
|
||||||
|
seed=seed,
|
||||||
|
audio=generate_audio,
|
||||||
|
prompt_extend=prompt_extend,
|
||||||
|
watermark=watermark,
|
||||||
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
response = await process_task(
|
if not initial_response.output:
|
||||||
{
|
raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}")
|
||||||
"auth_token": cls.hidden.auth_token_comfy_org,
|
response = await poll_op(
|
||||||
"comfy_api_key": cls.hidden.api_key_comfy_org,
|
cls,
|
||||||
},
|
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
|
||||||
"/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis",
|
|
||||||
request_model=Image2VideoTaskCreationRequest,
|
|
||||||
response_model=VideoTaskStatusResponse,
|
response_model=VideoTaskStatusResponse,
|
||||||
payload=payload,
|
status_extractor=lambda x: x.output.task_status,
|
||||||
node_id=cls.hidden.unique_id,
|
|
||||||
estimated_duration=120 * int(duration / 5),
|
estimated_duration=120 * int(duration / 5),
|
||||||
poll_interval=6,
|
poll_interval=6,
|
||||||
)
|
)
|
||||||
return comfy_io.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))
|
||||||
|
|
||||||
|
|
||||||
class WanApiExtension(ComfyExtension):
|
class WanApiExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[comfy_io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
WanTextToImageApi,
|
WanTextToImageApi,
|
||||||
WanImageToImageApi,
|
WanImageToImageApi,
|
||||||
|
|||||||
@ -0,0 +1,97 @@
|
|||||||
|
from ._helpers import get_fs_object_size
|
||||||
|
from .client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
poll_op,
|
||||||
|
poll_op_raw,
|
||||||
|
sync_op,
|
||||||
|
sync_op_raw,
|
||||||
|
)
|
||||||
|
from .conversions import (
|
||||||
|
audio_bytes_to_audio_input,
|
||||||
|
audio_input_to_mp3,
|
||||||
|
audio_to_base64_string,
|
||||||
|
bytesio_to_image_tensor,
|
||||||
|
downscale_image_tensor,
|
||||||
|
image_tensor_pair_to_batch,
|
||||||
|
pil_to_bytesio,
|
||||||
|
resize_mask_to_image,
|
||||||
|
tensor_to_base64_string,
|
||||||
|
tensor_to_bytesio,
|
||||||
|
tensor_to_pil,
|
||||||
|
text_filepath_to_base64_string,
|
||||||
|
text_filepath_to_data_uri,
|
||||||
|
trim_video,
|
||||||
|
video_to_base64_string,
|
||||||
|
)
|
||||||
|
from .download_helpers import (
|
||||||
|
download_url_as_bytesio,
|
||||||
|
download_url_to_bytesio,
|
||||||
|
download_url_to_image_tensor,
|
||||||
|
download_url_to_video_output,
|
||||||
|
)
|
||||||
|
from .upload_helpers import (
|
||||||
|
upload_audio_to_comfyapi,
|
||||||
|
upload_file_to_comfyapi,
|
||||||
|
upload_images_to_comfyapi,
|
||||||
|
upload_video_to_comfyapi,
|
||||||
|
)
|
||||||
|
from .validation_utils import (
|
||||||
|
get_number_of_images,
|
||||||
|
validate_aspect_ratio_string,
|
||||||
|
validate_audio_duration,
|
||||||
|
validate_container_format_is_mp4,
|
||||||
|
validate_image_aspect_ratio,
|
||||||
|
validate_image_dimensions,
|
||||||
|
validate_images_aspect_ratio_closeness,
|
||||||
|
validate_string,
|
||||||
|
validate_video_dimensions,
|
||||||
|
validate_video_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# API client
|
||||||
|
"ApiEndpoint",
|
||||||
|
"poll_op",
|
||||||
|
"poll_op_raw",
|
||||||
|
"sync_op",
|
||||||
|
"sync_op_raw",
|
||||||
|
# Upload helpers
|
||||||
|
"upload_audio_to_comfyapi",
|
||||||
|
"upload_file_to_comfyapi",
|
||||||
|
"upload_images_to_comfyapi",
|
||||||
|
"upload_video_to_comfyapi",
|
||||||
|
# Download helpers
|
||||||
|
"download_url_as_bytesio",
|
||||||
|
"download_url_to_bytesio",
|
||||||
|
"download_url_to_image_tensor",
|
||||||
|
"download_url_to_video_output",
|
||||||
|
# Conversions
|
||||||
|
"audio_bytes_to_audio_input",
|
||||||
|
"audio_input_to_mp3",
|
||||||
|
"audio_to_base64_string",
|
||||||
|
"bytesio_to_image_tensor",
|
||||||
|
"downscale_image_tensor",
|
||||||
|
"image_tensor_pair_to_batch",
|
||||||
|
"pil_to_bytesio",
|
||||||
|
"resize_mask_to_image",
|
||||||
|
"tensor_to_base64_string",
|
||||||
|
"tensor_to_bytesio",
|
||||||
|
"tensor_to_pil",
|
||||||
|
"text_filepath_to_base64_string",
|
||||||
|
"text_filepath_to_data_uri",
|
||||||
|
"trim_video",
|
||||||
|
"video_to_base64_string",
|
||||||
|
# Validation utilities
|
||||||
|
"get_number_of_images",
|
||||||
|
"validate_aspect_ratio_string",
|
||||||
|
"validate_audio_duration",
|
||||||
|
"validate_container_format_is_mp4",
|
||||||
|
"validate_image_aspect_ratio",
|
||||||
|
"validate_image_dimensions",
|
||||||
|
"validate_images_aspect_ratio_closeness",
|
||||||
|
"validate_string",
|
||||||
|
"validate_video_dimensions",
|
||||||
|
"validate_video_duration",
|
||||||
|
# Misc functions
|
||||||
|
"get_fs_object_size",
|
||||||
|
]
|
||||||
71
comfy_api_nodes/util/_helpers.py
Normal file
71
comfy_api_nodes/util/_helpers.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy.model_management import processing_interrupted
|
||||||
|
from comfy_api.latest import IO
|
||||||
|
|
||||||
|
from .common_exceptions import ProcessingInterrupted
|
||||||
|
|
||||||
|
|
||||||
|
def is_processing_interrupted() -> bool:
|
||||||
|
"""Return True if user/runtime requested interruption."""
|
||||||
|
return processing_interrupted()
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_id(node_cls: type[IO.ComfyNode]) -> str:
|
||||||
|
return node_cls.hidden.unique_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||||
|
if node_cls.hidden.auth_token_comfy_org:
|
||||||
|
return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"}
|
||||||
|
if node_cls.hidden.api_key_comfy_org:
|
||||||
|
return {"X-API-KEY": node_cls.hidden.api_key_comfy_org}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def default_base_url() -> str:
|
||||||
|
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||||
|
|
||||||
|
|
||||||
|
async def sleep_with_interrupt(
|
||||||
|
seconds: float,
|
||||||
|
node_cls: Optional[type[IO.ComfyNode]],
|
||||||
|
label: Optional[str] = None,
|
||||||
|
start_ts: Optional[float] = None,
|
||||||
|
estimated_total: Optional[int] = None,
|
||||||
|
*,
|
||||||
|
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sleep in 1s slices while:
|
||||||
|
- Checking for interruption (raises ProcessingInterrupted).
|
||||||
|
- Optionally emitting time progress via display_callback (if provided).
|
||||||
|
"""
|
||||||
|
end = time.monotonic() + seconds
|
||||||
|
while True:
|
||||||
|
if is_processing_interrupted():
|
||||||
|
raise ProcessingInterrupted("Task cancelled")
|
||||||
|
now = time.monotonic()
|
||||||
|
if start_ts is not None and label and display_callback:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
display_callback(node_cls, label, int(now - start_ts), estimated_total)
|
||||||
|
if now >= end:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(min(1.0, end - now))
|
||||||
|
|
||||||
|
|
||||||
|
def mimetype_to_extension(mime_type: str) -> str:
|
||||||
|
"""Converts a MIME type to a file extension."""
|
||||||
|
return mime_type.split("/")[-1].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
|
||||||
|
if isinstance(path_or_object, str):
|
||||||
|
return os.path.getsize(path_or_object)
|
||||||
|
return len(path_or_object.getvalue())
|
||||||
936
comfy_api_nodes/util/client.py
Normal file
936
comfy_api_nodes/util/client.py
Normal file
@ -0,0 +1,936 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||||
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from comfy import utils
|
||||||
|
from comfy_api.latest import IO
|
||||||
|
from server import PromptServer
|
||||||
|
|
||||||
|
from . import request_logger
|
||||||
|
from ._helpers import (
|
||||||
|
default_base_url,
|
||||||
|
get_auth_header,
|
||||||
|
get_node_id,
|
||||||
|
is_processing_interrupted,
|
||||||
|
sleep_with_interrupt,
|
||||||
|
)
|
||||||
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||||
|
|
||||||
|
M = TypeVar("M", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class ApiEndpoint:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
|
||||||
|
*,
|
||||||
|
query_params: Optional[dict[str, Any]] = None,
|
||||||
|
headers: Optional[dict[str, str]] = None,
|
||||||
|
):
|
||||||
|
self.path = path
|
||||||
|
self.method = method
|
||||||
|
self.query_params = query_params or {}
|
||||||
|
self.headers = headers or {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _RequestConfig:
|
||||||
|
node_cls: type[IO.ComfyNode]
|
||||||
|
endpoint: ApiEndpoint
|
||||||
|
timeout: float
|
||||||
|
content_type: str
|
||||||
|
data: Optional[dict[str, Any]]
|
||||||
|
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
|
||||||
|
multipart_parser: Optional[Callable]
|
||||||
|
max_retries: int
|
||||||
|
retry_delay: float
|
||||||
|
retry_backoff: float
|
||||||
|
wait_label: str = "Waiting"
|
||||||
|
monitor_progress: bool = True
|
||||||
|
estimated_total: Optional[int] = None
|
||||||
|
final_label_on_success: Optional[str] = "Completed"
|
||||||
|
progress_origin_ts: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _PollUIState:
|
||||||
|
started: float
|
||||||
|
status_label: str = "Queued"
|
||||||
|
is_queued: bool = True
|
||||||
|
price: Optional[float] = None
|
||||||
|
estimated_duration: Optional[int] = None
|
||||||
|
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||||
|
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||||||
|
|
||||||
|
|
||||||
|
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||||
|
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
|
||||||
|
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
||||||
|
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_op(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
endpoint: ApiEndpoint,
|
||||||
|
*,
|
||||||
|
response_model: Type[M],
|
||||||
|
data: Optional[BaseModel] = None,
|
||||||
|
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||||
|
content_type: str = "application/json",
|
||||||
|
timeout: float = 3600.0,
|
||||||
|
multipart_parser: Optional[Callable] = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
|
retry_backoff: float = 2.0,
|
||||||
|
wait_label: str = "Waiting for server",
|
||||||
|
estimated_duration: Optional[int] = None,
|
||||||
|
final_label_on_success: Optional[str] = "Completed",
|
||||||
|
progress_origin_ts: Optional[float] = None,
|
||||||
|
monitor_progress: bool = True,
|
||||||
|
) -> M:
|
||||||
|
raw = await sync_op_raw(
|
||||||
|
cls,
|
||||||
|
endpoint,
|
||||||
|
data=data,
|
||||||
|
files=files,
|
||||||
|
content_type=content_type,
|
||||||
|
timeout=timeout,
|
||||||
|
multipart_parser=multipart_parser,
|
||||||
|
max_retries=max_retries,
|
||||||
|
retry_delay=retry_delay,
|
||||||
|
retry_backoff=retry_backoff,
|
||||||
|
wait_label=wait_label,
|
||||||
|
estimated_duration=estimated_duration,
|
||||||
|
as_binary=False,
|
||||||
|
final_label_on_success=final_label_on_success,
|
||||||
|
progress_origin_ts=progress_origin_ts,
|
||||||
|
monitor_progress=monitor_progress,
|
||||||
|
)
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||||
|
return _validate_or_raise(response_model, raw)
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_op(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
poll_endpoint: ApiEndpoint,
|
||||||
|
*,
|
||||||
|
response_model: Type[M],
|
||||||
|
status_extractor: Callable[[M], Optional[Union[str, int]]],
|
||||||
|
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
|
||||||
|
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||||
|
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||||
|
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||||
|
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||||
|
data: Optional[BaseModel] = None,
|
||||||
|
poll_interval: float = 5.0,
|
||||||
|
max_poll_attempts: int = 120,
|
||||||
|
timeout_per_poll: float = 120.0,
|
||||||
|
max_retries_per_poll: int = 3,
|
||||||
|
retry_delay_per_poll: float = 1.0,
|
||||||
|
retry_backoff_per_poll: float = 2.0,
|
||||||
|
estimated_duration: Optional[int] = None,
|
||||||
|
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||||
|
cancel_timeout: float = 10.0,
|
||||||
|
) -> M:
|
||||||
|
raw = await poll_op_raw(
|
||||||
|
cls,
|
||||||
|
poll_endpoint=poll_endpoint,
|
||||||
|
status_extractor=_wrap_model_extractor(response_model, status_extractor),
|
||||||
|
progress_extractor=_wrap_model_extractor(response_model, progress_extractor),
|
||||||
|
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||||
|
completed_statuses=completed_statuses,
|
||||||
|
failed_statuses=failed_statuses,
|
||||||
|
queued_statuses=queued_statuses,
|
||||||
|
data=data,
|
||||||
|
poll_interval=poll_interval,
|
||||||
|
max_poll_attempts=max_poll_attempts,
|
||||||
|
timeout_per_poll=timeout_per_poll,
|
||||||
|
max_retries_per_poll=max_retries_per_poll,
|
||||||
|
retry_delay_per_poll=retry_delay_per_poll,
|
||||||
|
retry_backoff_per_poll=retry_backoff_per_poll,
|
||||||
|
estimated_duration=estimated_duration,
|
||||||
|
cancel_endpoint=cancel_endpoint,
|
||||||
|
cancel_timeout=cancel_timeout,
|
||||||
|
)
|
||||||
|
if not isinstance(raw, dict):
|
||||||
|
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
|
||||||
|
return _validate_or_raise(response_model, raw)
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_op_raw(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
endpoint: ApiEndpoint,
|
||||||
|
*,
|
||||||
|
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||||
|
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||||
|
content_type: str = "application/json",
|
||||||
|
timeout: float = 3600.0,
|
||||||
|
multipart_parser: Optional[Callable] = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
|
retry_backoff: float = 2.0,
|
||||||
|
wait_label: str = "Waiting for server",
|
||||||
|
estimated_duration: Optional[int] = None,
|
||||||
|
as_binary: bool = False,
|
||||||
|
final_label_on_success: Optional[str] = "Completed",
|
||||||
|
progress_origin_ts: Optional[float] = None,
|
||||||
|
monitor_progress: bool = True,
|
||||||
|
) -> Union[dict[str, Any], bytes]:
|
||||||
|
"""
|
||||||
|
Make a single network request.
|
||||||
|
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||||
|
- If as_binary=True: returns bytes.
|
||||||
|
"""
|
||||||
|
if isinstance(data, BaseModel):
|
||||||
|
data = data.model_dump(exclude_none=True)
|
||||||
|
for k, v in list(data.items()):
|
||||||
|
if isinstance(v, Enum):
|
||||||
|
data[k] = v.value
|
||||||
|
cfg = _RequestConfig(
|
||||||
|
node_cls=cls,
|
||||||
|
endpoint=endpoint,
|
||||||
|
timeout=timeout,
|
||||||
|
content_type=content_type,
|
||||||
|
data=data,
|
||||||
|
files=files,
|
||||||
|
multipart_parser=multipart_parser,
|
||||||
|
max_retries=max_retries,
|
||||||
|
retry_delay=retry_delay,
|
||||||
|
retry_backoff=retry_backoff,
|
||||||
|
wait_label=wait_label,
|
||||||
|
monitor_progress=monitor_progress,
|
||||||
|
estimated_total=estimated_duration,
|
||||||
|
final_label_on_success=final_label_on_success,
|
||||||
|
progress_origin_ts=progress_origin_ts,
|
||||||
|
)
|
||||||
|
return await _request_base(cfg, expect_binary=as_binary)
|
||||||
|
|
||||||
|
|
||||||
|
async def poll_op_raw(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
poll_endpoint: ApiEndpoint,
|
||||||
|
*,
|
||||||
|
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
||||||
|
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
||||||
|
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||||
|
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||||
|
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||||
|
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||||
|
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||||
|
poll_interval: float = 5.0,
|
||||||
|
max_poll_attempts: int = 120,
|
||||||
|
timeout_per_poll: float = 120.0,
|
||||||
|
max_retries_per_poll: int = 3,
|
||||||
|
retry_delay_per_poll: float = 1.0,
|
||||||
|
retry_backoff_per_poll: float = 2.0,
|
||||||
|
estimated_duration: Optional[int] = None,
|
||||||
|
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||||
|
cancel_timeout: float = 10.0,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing,
|
||||||
|
checks interruption every second, and calls Cancel endpoint (if provided) on interruption.
|
||||||
|
|
||||||
|
Uses default complete, failed and queued states assumption.
|
||||||
|
|
||||||
|
Returns the final JSON response from the poll endpoint.
|
||||||
|
"""
|
||||||
|
completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses)
|
||||||
|
failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses)
|
||||||
|
queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses)
|
||||||
|
started = time.monotonic()
|
||||||
|
consumed_attempts = 0 # counts only non-queued polls
|
||||||
|
|
||||||
|
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||||
|
last_progress: Optional[int] = None
|
||||||
|
|
||||||
|
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||||
|
stop_ticker = asyncio.Event()
|
||||||
|
|
||||||
|
async def _ticker():
|
||||||
|
"""Emit a UI update every second while polling is in progress."""
|
||||||
|
try:
|
||||||
|
while not stop_ticker.is_set():
|
||||||
|
if is_processing_interrupted():
|
||||||
|
break
|
||||||
|
now = time.monotonic()
|
||||||
|
proc_elapsed = state.base_processing_elapsed + (
|
||||||
|
(now - state.active_since) if state.active_since is not None else 0.0
|
||||||
|
)
|
||||||
|
_display_time_progress(
|
||||||
|
cls,
|
||||||
|
status=state.status_label,
|
||||||
|
elapsed_seconds=int(now - state.started),
|
||||||
|
estimated_total=state.estimated_duration,
|
||||||
|
price=state.price,
|
||||||
|
is_queued=state.is_queued,
|
||||||
|
processing_elapsed_seconds=int(proc_elapsed),
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
except Exception as exc:
|
||||||
|
logging.debug("Polling ticker exited: %s", exc)
|
||||||
|
|
||||||
|
ticker_task = asyncio.create_task(_ticker())
|
||||||
|
try:
|
||||||
|
while consumed_attempts < max_poll_attempts:
|
||||||
|
try:
|
||||||
|
resp_json = await sync_op_raw(
|
||||||
|
cls,
|
||||||
|
poll_endpoint,
|
||||||
|
data=data,
|
||||||
|
timeout=timeout_per_poll,
|
||||||
|
max_retries=max_retries_per_poll,
|
||||||
|
retry_delay=retry_delay_per_poll,
|
||||||
|
retry_backoff=retry_backoff_per_poll,
|
||||||
|
wait_label="Checking",
|
||||||
|
estimated_duration=None,
|
||||||
|
as_binary=False,
|
||||||
|
final_label_on_success=None,
|
||||||
|
monitor_progress=False,
|
||||||
|
)
|
||||||
|
if not isinstance(resp_json, dict):
|
||||||
|
raise Exception("Polling endpoint returned non-JSON response.")
|
||||||
|
except ProcessingInterrupted:
|
||||||
|
if cancel_endpoint:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await sync_op_raw(
|
||||||
|
cls,
|
||||||
|
cancel_endpoint,
|
||||||
|
timeout=cancel_timeout,
|
||||||
|
max_retries=0,
|
||||||
|
wait_label="Cancelling task",
|
||||||
|
estimated_duration=None,
|
||||||
|
as_binary=False,
|
||||||
|
final_label_on_success=None,
|
||||||
|
monitor_progress=False,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
status = _normalize_status_value(status_extractor(resp_json))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Status extraction failed: %s", e)
|
||||||
|
status = None
|
||||||
|
|
||||||
|
if price_extractor:
|
||||||
|
new_price = price_extractor(resp_json)
|
||||||
|
if new_price is not None:
|
||||||
|
state.price = new_price
|
||||||
|
|
||||||
|
if progress_extractor:
|
||||||
|
new_progress = progress_extractor(resp_json)
|
||||||
|
if new_progress is not None and last_progress != new_progress:
|
||||||
|
progress_bar.update_absolute(new_progress, total=100)
|
||||||
|
last_progress = new_progress
|
||||||
|
|
||||||
|
now_ts = time.monotonic()
|
||||||
|
is_queued = status in queued_states
|
||||||
|
|
||||||
|
if is_queued:
|
||||||
|
if state.active_since is not None: # If we just moved from active -> queued, close the active interval
|
||||||
|
state.base_processing_elapsed += now_ts - state.active_since
|
||||||
|
state.active_since = None
|
||||||
|
else:
|
||||||
|
if state.active_since is None: # If we just moved from queued -> active, open a new active interval
|
||||||
|
state.active_since = now_ts
|
||||||
|
|
||||||
|
state.is_queued = is_queued
|
||||||
|
state.status_label = status or ("Queued" if is_queued else "Processing")
|
||||||
|
if status in completed_states:
|
||||||
|
if state.active_since is not None:
|
||||||
|
state.base_processing_elapsed += now_ts - state.active_since
|
||||||
|
state.active_since = None
|
||||||
|
stop_ticker.set()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await ticker_task
|
||||||
|
|
||||||
|
if progress_bar and last_progress != 100:
|
||||||
|
progress_bar.update_absolute(100, total=100)
|
||||||
|
|
||||||
|
_display_time_progress(
|
||||||
|
cls,
|
||||||
|
status=status if status else "Completed",
|
||||||
|
elapsed_seconds=int(now_ts - started),
|
||||||
|
estimated_total=estimated_duration,
|
||||||
|
price=state.price,
|
||||||
|
is_queued=False,
|
||||||
|
processing_elapsed_seconds=int(state.base_processing_elapsed),
|
||||||
|
)
|
||||||
|
return resp_json
|
||||||
|
|
||||||
|
if status in failed_states:
|
||||||
|
msg = f"Task failed: {json.dumps(resp_json)}"
|
||||||
|
logging.error(msg)
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await sleep_with_interrupt(poll_interval, cls, None, None, None)
|
||||||
|
except ProcessingInterrupted:
|
||||||
|
if cancel_endpoint:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await sync_op_raw(
|
||||||
|
cls,
|
||||||
|
cancel_endpoint,
|
||||||
|
timeout=cancel_timeout,
|
||||||
|
max_retries=0,
|
||||||
|
wait_label="Cancelling task",
|
||||||
|
estimated_duration=None,
|
||||||
|
as_binary=False,
|
||||||
|
final_label_on_success=None,
|
||||||
|
monitor_progress=False,
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
if not is_queued:
|
||||||
|
consumed_attempts += 1
|
||||||
|
|
||||||
|
raise Exception(
|
||||||
|
f"Polling timed out after {max_poll_attempts} non-queued attempts "
|
||||||
|
f"(~{int(max_poll_attempts * poll_interval)}s of active polling)."
|
||||||
|
)
|
||||||
|
except ProcessingInterrupted:
|
||||||
|
raise
|
||||||
|
except (LocalNetworkError, ApiServerError):
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"Polling aborted due to error: {e}") from e
|
||||||
|
finally:
|
||||||
|
stop_ticker.set()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await ticker_task
|
||||||
|
|
||||||
|
|
||||||
|
def _display_text(
|
||||||
|
node_cls: type[IO.ComfyNode],
|
||||||
|
text: Optional[str],
|
||||||
|
*,
|
||||||
|
status: Optional[Union[str, int]] = None,
|
||||||
|
price: Optional[float] = None,
|
||||||
|
) -> None:
|
||||||
|
display_lines: list[str] = []
|
||||||
|
if status:
|
||||||
|
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||||
|
if price is not None:
|
||||||
|
display_lines.append(f"Price: ${float(price):,.4f}")
|
||||||
|
if text is not None:
|
||||||
|
display_lines.append(text)
|
||||||
|
if display_lines:
|
||||||
|
PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls))
|
||||||
|
|
||||||
|
|
||||||
|
def _display_time_progress(
|
||||||
|
node_cls: type[IO.ComfyNode],
|
||||||
|
status: Optional[Union[str, int]],
|
||||||
|
elapsed_seconds: int,
|
||||||
|
estimated_total: Optional[int] = None,
|
||||||
|
*,
|
||||||
|
price: Optional[float] = None,
|
||||||
|
is_queued: Optional[bool] = None,
|
||||||
|
processing_elapsed_seconds: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||||
|
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||||
|
remaining = max(0, int(estimated_total) - int(pe))
|
||||||
|
time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)"
|
||||||
|
else:
|
||||||
|
time_line = f"Time elapsed: {int(elapsed_seconds)}s"
|
||||||
|
_display_text(node_cls, time_line, status=status, price=price)
|
||||||
|
|
||||||
|
|
||||||
|
async def _diagnose_connectivity() -> dict[str, bool]:
|
||||||
|
"""Best-effort connectivity diagnostics to distinguish local vs. server issues."""
|
||||||
|
results = {
|
||||||
|
"internet_accessible": False,
|
||||||
|
"api_accessible": False,
|
||||||
|
}
|
||||||
|
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||||
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
|
with contextlib.suppress(ClientError, OSError):
|
||||||
|
async with session.get("https://www.google.com") as resp:
|
||||||
|
results["internet_accessible"] = resp.status < 500
|
||||||
|
if not results["internet_accessible"]:
|
||||||
|
return results
|
||||||
|
|
||||||
|
parsed = urlparse(default_base_url())
|
||||||
|
health_url = f"{parsed.scheme}://{parsed.netloc}/health"
|
||||||
|
with contextlib.suppress(ClientError, OSError):
|
||||||
|
async with session.get(health_url) as resp:
|
||||||
|
results["api_accessible"] = resp.status < 500
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||||
|
"""Normalize (filename, value, content_type)."""
|
||||||
|
if len(t) == 2:
|
||||||
|
return t[0], t[1], "application/octet-stream"
|
||||||
|
if len(t) == 3:
|
||||||
|
return t[0], t[1], t[2]
|
||||||
|
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||||||
|
params = dict(endpoint_params or {})
|
||||||
|
if method.upper() == "GET" and data:
|
||||||
|
for k, v in data.items():
|
||||||
|
if v is not None:
|
||||||
|
params[k] = v
|
||||||
|
return params
|
||||||
|
|
||||||
|
|
||||||
|
def _friendly_http_message(status: int, body: Any) -> str:
|
||||||
|
if status == 401:
|
||||||
|
return "Unauthorized: Please login first to use this node."
|
||||||
|
if status == 402:
|
||||||
|
return "Payment Required: Please add credits to your account to use this node."
|
||||||
|
if status == 409:
|
||||||
|
return "There is a problem with your account. Please contact support@comfy.org."
|
||||||
|
if status == 429:
|
||||||
|
return "Rate Limit Exceeded: Please try again later."
|
||||||
|
try:
|
||||||
|
if isinstance(body, dict):
|
||||||
|
err = body.get("error")
|
||||||
|
if isinstance(err, dict):
|
||||||
|
msg = err.get("message")
|
||||||
|
typ = err.get("type")
|
||||||
|
if msg and typ:
|
||||||
|
return f"API Error: {msg} (Type: {typ})"
|
||||||
|
if msg:
|
||||||
|
return f"API Error: {msg}"
|
||||||
|
return f"API Error: {json.dumps(body)}"
|
||||||
|
else:
|
||||||
|
txt = str(body)
|
||||||
|
if len(txt) <= 200:
|
||||||
|
return f"API Error (raw): {txt}"
|
||||||
|
return f"API Error (status {status})"
|
||||||
|
except Exception:
|
||||||
|
return f"HTTP {status}: Unknown error"
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||||||
|
slug = path.strip("/").replace("/", "_") or "op"
|
||||||
|
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
|
||||||
|
def _snapshot_request_body_for_logging(
|
||||||
|
content_type: str,
|
||||||
|
method: str,
|
||||||
|
data: Optional[dict[str, Any]],
|
||||||
|
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||||||
|
) -> Optional[Union[dict[str, Any], str]]:
|
||||||
|
if method.upper() == "GET":
|
||||||
|
return None
|
||||||
|
if content_type == "multipart/form-data":
|
||||||
|
form_fields = sorted([k for k, v in (data or {}).items() if v is not None])
|
||||||
|
file_fields: list[dict[str, str]] = []
|
||||||
|
if files:
|
||||||
|
file_iter = files if isinstance(files, list) else list(files.items())
|
||||||
|
for field_name, file_obj in file_iter:
|
||||||
|
if file_obj is None:
|
||||||
|
continue
|
||||||
|
if isinstance(file_obj, tuple):
|
||||||
|
filename = file_obj[0]
|
||||||
|
else:
|
||||||
|
filename = getattr(file_obj, "name", field_name)
|
||||||
|
file_fields.append({"field": field_name, "filename": str(filename or "")})
|
||||||
|
return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields}
|
||||||
|
if content_type == "application/x-www-form-urlencoded":
|
||||||
|
return data or {}
|
||||||
|
return data or {}
|
||||||
|
|
||||||
|
|
||||||
|
async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||||
|
"""Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors."""
|
||||||
|
url = cfg.endpoint.path
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
|
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||||
|
|
||||||
|
method = cfg.endpoint.method
|
||||||
|
params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None)
|
||||||
|
|
||||||
|
async def _monitor(stop_evt: asyncio.Event, start_ts: float):
|
||||||
|
"""Every second: update elapsed time and signal interruption."""
|
||||||
|
try:
|
||||||
|
while not stop_evt.is_set():
|
||||||
|
if is_processing_interrupted():
|
||||||
|
return
|
||||||
|
if cfg.monitor_progress:
|
||||||
|
_display_time_progress(
|
||||||
|
cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return # normal shutdown
|
||||||
|
|
||||||
|
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
|
||||||
|
attempt = 0
|
||||||
|
delay = cfg.retry_delay
|
||||||
|
operation_succeeded: bool = False
|
||||||
|
final_elapsed_seconds: Optional[int] = None
|
||||||
|
while True:
|
||||||
|
attempt += 1
|
||||||
|
stop_event = asyncio.Event()
|
||||||
|
monitor_task: Optional[asyncio.Task] = None
|
||||||
|
sess: Optional[aiohttp.ClientSession] = None
|
||||||
|
|
||||||
|
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||||
|
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||||
|
|
||||||
|
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||||
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
|
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||||
|
if cfg.endpoint.headers:
|
||||||
|
payload_headers.update(cfg.endpoint.headers)
|
||||||
|
|
||||||
|
payload_kw: dict[str, Any] = {"headers": payload_headers}
|
||||||
|
if method == "GET":
|
||||||
|
payload_headers.pop("Content-Type", None)
|
||||||
|
request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files)
|
||||||
|
try:
|
||||||
|
if cfg.monitor_progress:
|
||||||
|
monitor_task = asyncio.create_task(_monitor(stop_event, start_time))
|
||||||
|
|
||||||
|
timeout = aiohttp.ClientTimeout(total=cfg.timeout)
|
||||||
|
sess = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
|
||||||
|
if cfg.content_type == "multipart/form-data" and method != "GET":
|
||||||
|
# aiohttp will set Content-Type boundary; remove any fixed Content-Type
|
||||||
|
payload_headers.pop("Content-Type", None)
|
||||||
|
if cfg.multipart_parser and cfg.data:
|
||||||
|
form = cfg.multipart_parser(cfg.data)
|
||||||
|
if not isinstance(form, aiohttp.FormData):
|
||||||
|
raise ValueError("multipart_parser must return aiohttp.FormData")
|
||||||
|
else:
|
||||||
|
form = aiohttp.FormData(default_to_multipart=True)
|
||||||
|
if cfg.data:
|
||||||
|
for k, v in cfg.data.items():
|
||||||
|
if v is None:
|
||||||
|
continue
|
||||||
|
form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)
|
||||||
|
if cfg.files:
|
||||||
|
file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items()
|
||||||
|
for field_name, file_obj in file_iter:
|
||||||
|
if file_obj is None:
|
||||||
|
continue
|
||||||
|
if isinstance(file_obj, tuple):
|
||||||
|
filename, file_value, content_type = _unpack_tuple(file_obj)
|
||||||
|
else:
|
||||||
|
filename = getattr(file_obj, "name", field_name)
|
||||||
|
file_value = file_obj
|
||||||
|
content_type = "application/octet-stream"
|
||||||
|
# Attempt to rewind BytesIO for retries
|
||||||
|
if isinstance(file_value, BytesIO):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
file_value.seek(0)
|
||||||
|
form.add_field(field_name, file_value, filename=filename, content_type=content_type)
|
||||||
|
payload_kw["data"] = form
|
||||||
|
elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET":
|
||||||
|
payload_headers["Content-Type"] = "application/x-www-form-urlencoded"
|
||||||
|
payload_kw["data"] = cfg.data or {}
|
||||||
|
elif method != "GET":
|
||||||
|
payload_headers["Content-Type"] = "application/json"
|
||||||
|
payload_kw["json"] = cfg.data or {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
request_headers=dict(payload_headers) if payload_headers else None,
|
||||||
|
request_params=dict(params) if params else None,
|
||||||
|
request_data=request_body_log,
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] request logging failed: %s", _log_e)
|
||||||
|
|
||||||
|
req_coro = sess.request(method, url, params=params, **payload_kw)
|
||||||
|
req_task = asyncio.create_task(req_coro)
|
||||||
|
|
||||||
|
# Race: request vs. monitor (interruption)
|
||||||
|
tasks = {req_task}
|
||||||
|
if monitor_task:
|
||||||
|
tasks.add(monitor_task)
|
||||||
|
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
if monitor_task and monitor_task in done:
|
||||||
|
# Interrupted – cancel the request and abort
|
||||||
|
if req_task in pending:
|
||||||
|
req_task.cancel()
|
||||||
|
raise ProcessingInterrupted("Task cancelled")
|
||||||
|
|
||||||
|
# Otherwise, request finished
|
||||||
|
resp = await req_task
|
||||||
|
async with resp:
|
||||||
|
if resp.status >= 400:
|
||||||
|
try:
|
||||||
|
body = await resp.json()
|
||||||
|
except (ContentTypeError, json.JSONDecodeError):
|
||||||
|
body = await resp.text()
|
||||||
|
if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
"HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
resp.status,
|
||||||
|
delay,
|
||||||
|
attempt,
|
||||||
|
cfg.max_retries,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=body,
|
||||||
|
error_message=_friendly_http_message(resp.status, body),
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||||
|
|
||||||
|
await sleep_with_interrupt(
|
||||||
|
delay,
|
||||||
|
cfg.node_cls,
|
||||||
|
cfg.wait_label if cfg.monitor_progress else None,
|
||||||
|
start_time if cfg.monitor_progress else None,
|
||||||
|
cfg.estimated_total,
|
||||||
|
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||||
|
)
|
||||||
|
delay *= cfg.retry_backoff
|
||||||
|
continue
|
||||||
|
msg = _friendly_http_message(resp.status, body)
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=body,
|
||||||
|
error_message=msg,
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||||
|
raise Exception(msg)
|
||||||
|
|
||||||
|
if expect_binary:
|
||||||
|
buff = bytearray()
|
||||||
|
last_tick = time.monotonic()
|
||||||
|
async for chunk in resp.content.iter_chunked(64 * 1024):
|
||||||
|
buff.extend(chunk)
|
||||||
|
now = time.monotonic()
|
||||||
|
if now - last_tick >= 1.0:
|
||||||
|
last_tick = now
|
||||||
|
if is_processing_interrupted():
|
||||||
|
raise ProcessingInterrupted("Task cancelled")
|
||||||
|
if cfg.monitor_progress:
|
||||||
|
_display_time_progress(
|
||||||
|
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||||
|
)
|
||||||
|
bytes_payload = bytes(buff)
|
||||||
|
operation_succeeded = True
|
||||||
|
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=bytes_payload,
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||||
|
return bytes_payload
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
payload = await resp.json()
|
||||||
|
response_content_to_log: Any = payload
|
||||||
|
except (ContentTypeError, json.JSONDecodeError):
|
||||||
|
text = await resp.text()
|
||||||
|
try:
|
||||||
|
payload = json.loads(text) if text else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
payload = {"_raw": text}
|
||||||
|
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||||
|
operation_succeeded = True
|
||||||
|
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=response_content_to_log,
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] response logging failed: %s", _log_e)
|
||||||
|
return payload
|
||||||
|
|
||||||
|
except ProcessingInterrupted:
|
||||||
|
logging.debug("Polling was interrupted by user")
|
||||||
|
raise
|
||||||
|
except (ClientError, OSError) as e:
|
||||||
|
if attempt <= cfg.max_retries:
|
||||||
|
logging.warning(
|
||||||
|
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
delay,
|
||||||
|
attempt,
|
||||||
|
cfg.max_retries,
|
||||||
|
str(e),
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
request_headers=dict(payload_headers) if payload_headers else None,
|
||||||
|
request_params=dict(params) if params else None,
|
||||||
|
request_data=request_body_log,
|
||||||
|
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] request error logging failed: %s", _log_e)
|
||||||
|
await sleep_with_interrupt(
|
||||||
|
delay,
|
||||||
|
cfg.node_cls,
|
||||||
|
cfg.wait_label if cfg.monitor_progress else None,
|
||||||
|
start_time if cfg.monitor_progress else None,
|
||||||
|
cfg.estimated_total,
|
||||||
|
display_callback=_display_time_progress if cfg.monitor_progress else None,
|
||||||
|
)
|
||||||
|
delay *= cfg.retry_backoff
|
||||||
|
continue
|
||||||
|
diag = await _diagnose_connectivity()
|
||||||
|
if not diag["internet_accessible"]:
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
request_headers=dict(payload_headers) if payload_headers else None,
|
||||||
|
request_params=dict(params) if params else None,
|
||||||
|
request_data=request_body_log,
|
||||||
|
error_message=f"LocalNetworkError: {str(e)}",
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||||
|
raise LocalNetworkError(
|
||||||
|
"Unable to connect to the API server due to local network issues. "
|
||||||
|
"Please check your internet connection and try again."
|
||||||
|
) from e
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method=method,
|
||||||
|
request_url=url,
|
||||||
|
request_headers=dict(payload_headers) if payload_headers else None,
|
||||||
|
request_params=dict(params) if params else None,
|
||||||
|
request_data=request_body_log,
|
||||||
|
error_message=f"ApiServerError: {str(e)}",
|
||||||
|
)
|
||||||
|
except Exception as _log_e:
|
||||||
|
logging.debug("[DEBUG] final error logging failed: %s", _log_e)
|
||||||
|
raise ApiServerError(
|
||||||
|
f"The API server at {default_base_url()} is currently unreachable. "
|
||||||
|
f"The service may be experiencing issues."
|
||||||
|
) from e
|
||||||
|
finally:
|
||||||
|
stop_event.set()
|
||||||
|
if monitor_task:
|
||||||
|
monitor_task.cancel()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await monitor_task
|
||||||
|
if sess:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await sess.close()
|
||||||
|
if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success:
|
||||||
|
_display_time_progress(
|
||||||
|
cfg.node_cls,
|
||||||
|
status=cfg.final_label_on_success,
|
||||||
|
elapsed_seconds=(
|
||||||
|
final_elapsed_seconds
|
||||||
|
if final_elapsed_seconds is not None
|
||||||
|
else int(time.monotonic() - start_time)
|
||||||
|
),
|
||||||
|
estimated_total=cfg.estimated_total,
|
||||||
|
price=None,
|
||||||
|
is_queued=False,
|
||||||
|
processing_elapsed_seconds=final_elapsed_seconds,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||||
|
try:
|
||||||
|
return response_model.model_validate(payload)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(
|
||||||
|
"Response validation failed for %s: %s",
|
||||||
|
getattr(response_model, "__name__", response_model),
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
raise Exception(
|
||||||
|
f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_model_extractor(
|
||||||
|
response_model: Type[M],
|
||||||
|
extractor: Optional[Callable[[M], Any]],
|
||||||
|
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||||||
|
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||||
|
Validates the dict into `response_model` before invoking `extractor`.
|
||||||
|
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||||
|
the same response for multiple extractors in a single poll attempt.
|
||||||
|
"""
|
||||||
|
if extractor is None:
|
||||||
|
return None
|
||||||
|
_cache: dict[int, M] = {}
|
||||||
|
|
||||||
|
def _wrapped(d: dict[str, Any]) -> Any:
|
||||||
|
try:
|
||||||
|
key = id(d)
|
||||||
|
model = _cache.get(key)
|
||||||
|
if model is None:
|
||||||
|
model = response_model.model_validate(d)
|
||||||
|
_cache[key] = model
|
||||||
|
return extractor(model)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Extractor failed (typed -> dict wrapper): %s", e)
|
||||||
|
raise
|
||||||
|
|
||||||
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
||||||
|
if not values:
|
||||||
|
return set()
|
||||||
|
out: set[Union[str, int]] = set()
|
||||||
|
for v in values:
|
||||||
|
nv = _normalize_status_value(v)
|
||||||
|
if nv is not None:
|
||||||
|
out.add(nv)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||||||
|
if isinstance(val, str):
|
||||||
|
return val.strip().lower()
|
||||||
|
return val
|
||||||
14
comfy_api_nodes/util/common_exceptions.py
Normal file
14
comfy_api_nodes/util/common_exceptions.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
class NetworkError(Exception):
|
||||||
|
"""Base exception for network-related errors with diagnostic information."""
|
||||||
|
|
||||||
|
|
||||||
|
class LocalNetworkError(NetworkError):
|
||||||
|
"""Exception raised when local network connectivity issues are detected."""
|
||||||
|
|
||||||
|
|
||||||
|
class ApiServerError(NetworkError):
|
||||||
|
"""Exception raised when the API server is unreachable but internet is working."""
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingInterrupted(Exception):
|
||||||
|
"""Operation was interrupted by user/runtime via processing_interrupted()."""
|
||||||
470
comfy_api_nodes/util/conversions.py
Normal file
470
comfy_api_nodes/util/conversions.py
Normal file
@ -0,0 +1,470 @@
|
|||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import mimetypes
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import av
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from comfy.utils import common_upscale
|
||||||
|
from comfy_api.latest import Input, InputImpl
|
||||||
|
from comfy_api.util import VideoCodec, VideoContainer
|
||||||
|
|
||||||
|
from ._helpers import mimetype_to_extension
|
||||||
|
|
||||||
|
|
||||||
|
def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
|
||||||
|
"""Converts image data from BytesIO to a torch.Tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_bytesio: BytesIO object containing the image data.
|
||||||
|
mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A torch.Tensor representing the image (1, H, W, C).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PIL.UnidentifiedImageError: If the image data cannot be identified.
|
||||||
|
ValueError: If the specified mode is invalid.
|
||||||
|
"""
|
||||||
|
image = Image.open(image_bytesio)
|
||||||
|
image = image.convert(mode)
|
||||||
|
image_array = np.array(image).astype(np.float32) / 255.0
|
||||||
|
return torch.from_numpy(image_array).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts a pair of image tensors to a batch tensor.
|
||||||
|
If the images are not the same size, the smaller image is resized to
|
||||||
|
match the larger image.
|
||||||
|
"""
|
||||||
|
if image1.shape[1:] != image2.shape[1:]:
|
||||||
|
image2 = common_upscale(
|
||||||
|
image2.movedim(-1, 1),
|
||||||
|
image1.shape[2],
|
||||||
|
image1.shape[1],
|
||||||
|
"bilinear",
|
||||||
|
"center",
|
||||||
|
).movedim(1, -1)
|
||||||
|
return torch.cat((image1, image2), dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_bytesio(
|
||||||
|
image: torch.Tensor,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
total_pixels: int = 2048 * 2048,
|
||||||
|
mime_type: str = "image/png",
|
||||||
|
) -> BytesIO:
|
||||||
|
"""Converts a torch.Tensor image to a named BytesIO object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Input torch.Tensor image.
|
||||||
|
name: Optional filename for the BytesIO object.
|
||||||
|
total_pixels: Maximum total pixels for potential downscaling.
|
||||||
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Named BytesIO object containing the image data, with pointer set to the start of buffer.
|
||||||
|
"""
|
||||||
|
if not mime_type:
|
||||||
|
mime_type = "image/png"
|
||||||
|
|
||||||
|
pil_image = tensor_to_pil(image, total_pixels=total_pixels)
|
||||||
|
img_binary = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||||
|
img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
|
||||||
|
return img_binary
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
|
||||||
|
"""Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
|
||||||
|
if len(image.shape) > 3:
|
||||||
|
image = image[0]
|
||||||
|
# TODO: remove alpha if not allowed and present
|
||||||
|
input_tensor = image.cpu()
|
||||||
|
input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze()
|
||||||
|
image_np = (input_tensor.numpy() * 255).astype(np.uint8)
|
||||||
|
img = Image.fromarray(image_np)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_base64_string(
|
||||||
|
image_tensor: torch.Tensor,
|
||||||
|
total_pixels: int = 2048 * 2048,
|
||||||
|
mime_type: str = "image/png",
|
||||||
|
) -> str:
|
||||||
|
"""Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_tensor: Input torch.Tensor image.
|
||||||
|
total_pixels: Maximum total pixels for potential downscaling.
|
||||||
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Base64 encoded string of the image.
|
||||||
|
"""
|
||||||
|
pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels)
|
||||||
|
img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type)
|
||||||
|
img_bytes = img_byte_arr.getvalue()
|
||||||
|
# Encode bytes to base64 string
|
||||||
|
base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
|
||||||
|
return base64_encoded_string
|
||||||
|
|
||||||
|
|
||||||
|
def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
|
||||||
|
"""Converts a PIL Image to a BytesIO object."""
|
||||||
|
if not mime_type:
|
||||||
|
mime_type = "image/png"
|
||||||
|
|
||||||
|
img_byte_arr = BytesIO()
|
||||||
|
# Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
|
||||||
|
pil_format = mime_type.split("/")[-1].upper()
|
||||||
|
if pil_format == "JPG":
|
||||||
|
pil_format = "JPEG"
|
||||||
|
img.save(img_byte_arr, format=pil_format)
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
return img_byte_arr
|
||||||
|
|
||||||
|
|
||||||
|
def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
|
||||||
|
"""Downscale input image tensor to roughly the specified total pixels."""
|
||||||
|
samples = image.movedim(-1, 1)
|
||||||
|
total = int(total_pixels)
|
||||||
|
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
|
||||||
|
if scale_by >= 1:
|
||||||
|
return image
|
||||||
|
width = round(samples.shape[3] * scale_by)
|
||||||
|
height = round(samples.shape[2] * scale_by)
|
||||||
|
|
||||||
|
s = common_upscale(samples, width, height, "lanczos", "disabled")
|
||||||
|
s = s.movedim(1, -1)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def tensor_to_data_uri(
|
||||||
|
image_tensor: torch.Tensor,
|
||||||
|
total_pixels: int = 2048 * 2048,
|
||||||
|
mime_type: str = "image/png",
|
||||||
|
) -> str:
|
||||||
|
"""Converts a tensor image to a Data URI string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_tensor: Input torch.Tensor image.
|
||||||
|
total_pixels: Maximum total pixels for potential downscaling.
|
||||||
|
mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Data URI string (e.g., 'data:image/png;base64,...').
|
||||||
|
"""
|
||||||
|
base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
|
||||||
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
|
|
||||||
|
|
||||||
|
def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str:
|
||||||
|
"""Converts an audio input to a base64 string."""
|
||||||
|
sample_rate: int = audio["sample_rate"]
|
||||||
|
waveform: torch.Tensor = audio["waveform"]
|
||||||
|
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||||
|
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||||
|
audio_bytes = audio_bytes_io.getvalue()
|
||||||
|
return base64.b64encode(audio_bytes).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def video_to_base64_string(
|
||||||
|
video: Input.Video,
|
||||||
|
container_format: VideoContainer = None,
|
||||||
|
codec: VideoCodec = None
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Converts a video input to a base64 string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: The video input to convert
|
||||||
|
container_format: Optional container format to use (defaults to video.container if available)
|
||||||
|
codec: Optional codec to use (defaults to video.codec if available)
|
||||||
|
"""
|
||||||
|
video_bytes_io = BytesIO()
|
||||||
|
|
||||||
|
# Use provided format/codec if specified, otherwise use video's own if available
|
||||||
|
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||||
|
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||||
|
|
||||||
|
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def audio_ndarray_to_bytesio(
|
||||||
|
audio_data_np: np.ndarray,
|
||||||
|
sample_rate: int,
|
||||||
|
container_format: str = "mp4",
|
||||||
|
codec_name: str = "aac",
|
||||||
|
) -> BytesIO:
|
||||||
|
"""
|
||||||
|
Encodes a numpy array of audio data into a BytesIO object.
|
||||||
|
"""
|
||||||
|
audio_bytes_io = BytesIO()
|
||||||
|
with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
|
||||||
|
audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
audio_data_np,
|
||||||
|
format="fltp",
|
||||||
|
layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
|
||||||
|
)
|
||||||
|
frame.sample_rate = sample_rate
|
||||||
|
frame.pts = 0
|
||||||
|
|
||||||
|
for packet in audio_stream.encode(frame):
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
# Flush stream
|
||||||
|
for packet in audio_stream.encode(None):
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
audio_bytes_io.seek(0)
|
||||||
|
return audio_bytes_io
|
||||||
|
|
||||||
|
|
||||||
|
def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Prepares audio waveform for av library by converting to a contiguous numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Contiguous numpy array of the audio waveform. If the audio was batched,
|
||||||
|
the first item is taken.
|
||||||
|
"""
|
||||||
|
if waveform.ndim != 3 or waveform.shape[0] != 1:
|
||||||
|
raise ValueError("Expected waveform tensor shape (1, channels, samples)")
|
||||||
|
|
||||||
|
# If batch is > 1, take first item
|
||||||
|
if waveform.shape[0] > 1:
|
||||||
|
waveform = waveform[0]
|
||||||
|
|
||||||
|
# Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
|
||||||
|
audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
|
||||||
|
if audio_data_np.dtype != np.float32:
|
||||||
|
audio_data_np = audio_data_np.astype(np.float32)
|
||||||
|
|
||||||
|
return audio_data_np
|
||||||
|
|
||||||
|
|
||||||
|
def audio_input_to_mp3(audio: Input.Audio) -> BytesIO:
|
||||||
|
waveform = audio["waveform"].cpu()
|
||||||
|
|
||||||
|
output_buffer = BytesIO()
|
||||||
|
output_container = av.open(output_buffer, mode="w", format="mp3")
|
||||||
|
|
||||||
|
out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"])
|
||||||
|
out_stream.bit_rate = 320000
|
||||||
|
|
||||||
|
frame = av.AudioFrame.from_ndarray(
|
||||||
|
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||||
|
format="flt",
|
||||||
|
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||||
|
)
|
||||||
|
frame.sample_rate = audio["sample_rate"]
|
||||||
|
frame.pts = 0
|
||||||
|
output_container.mux(out_stream.encode(frame))
|
||||||
|
output_container.mux(out_stream.encode(None))
|
||||||
|
output_container.close()
|
||||||
|
output_buffer.seek(0)
|
||||||
|
return output_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def trim_video(video: Input.Video, duration_sec: float) -> Input.Video:
|
||||||
|
"""
|
||||||
|
Returns a new VideoInput object trimmed from the beginning to the specified duration,
|
||||||
|
using av to avoid loading entire video into memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
video: Input video to trim
|
||||||
|
duration_sec: Duration in seconds to keep from the beginning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
VideoFromFile object that owns the output buffer
|
||||||
|
"""
|
||||||
|
output_buffer = BytesIO()
|
||||||
|
input_container = None
|
||||||
|
output_container = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get the stream source - this avoids loading entire video into memory
|
||||||
|
# when the source is already a file path
|
||||||
|
input_source = video.get_stream_source()
|
||||||
|
|
||||||
|
# Open containers
|
||||||
|
input_container = av.open(input_source, mode="r")
|
||||||
|
output_container = av.open(output_buffer, mode="w", format="mp4")
|
||||||
|
|
||||||
|
# Set up output streams for re-encoding
|
||||||
|
video_stream = None
|
||||||
|
audio_stream = None
|
||||||
|
|
||||||
|
for stream in input_container.streams:
|
||||||
|
logging.info("Found stream: type=%s, class=%s", stream.type, type(stream))
|
||||||
|
if isinstance(stream, av.VideoStream):
|
||||||
|
# Create output video stream with same parameters
|
||||||
|
video_stream = output_container.add_stream("h264", rate=stream.average_rate)
|
||||||
|
video_stream.width = stream.width
|
||||||
|
video_stream.height = stream.height
|
||||||
|
video_stream.pix_fmt = "yuv420p"
|
||||||
|
logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate)
|
||||||
|
elif isinstance(stream, av.AudioStream):
|
||||||
|
# Create output audio stream with same parameters
|
||||||
|
audio_stream = output_container.add_stream("aac", rate=stream.sample_rate)
|
||||||
|
audio_stream.sample_rate = stream.sample_rate
|
||||||
|
audio_stream.layout = stream.layout
|
||||||
|
logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels)
|
||||||
|
|
||||||
|
# Calculate target frame count that's divisible by 16
|
||||||
|
fps = input_container.streams.video[0].average_rate
|
||||||
|
estimated_frames = int(duration_sec * fps)
|
||||||
|
target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16
|
||||||
|
|
||||||
|
if target_frames == 0:
|
||||||
|
raise ValueError("Video too short: need at least 16 frames for Moonvalley")
|
||||||
|
|
||||||
|
frame_count = 0
|
||||||
|
audio_frame_count = 0
|
||||||
|
|
||||||
|
# Decode and re-encode video frames
|
||||||
|
if video_stream:
|
||||||
|
for frame in input_container.decode(video=0):
|
||||||
|
if frame_count >= target_frames:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Re-encode frame
|
||||||
|
for packet in video_stream.encode(frame):
|
||||||
|
output_container.mux(packet)
|
||||||
|
frame_count += 1
|
||||||
|
|
||||||
|
# Flush encoder
|
||||||
|
for packet in video_stream.encode():
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames)
|
||||||
|
|
||||||
|
# Decode and re-encode audio frames
|
||||||
|
if audio_stream:
|
||||||
|
input_container.seek(0) # Reset to beginning for audio
|
||||||
|
for frame in input_container.decode(audio=0):
|
||||||
|
if frame.time >= duration_sec:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Re-encode frame
|
||||||
|
for packet in audio_stream.encode(frame):
|
||||||
|
output_container.mux(packet)
|
||||||
|
audio_frame_count += 1
|
||||||
|
|
||||||
|
# Flush encoder
|
||||||
|
for packet in audio_stream.encode():
|
||||||
|
output_container.mux(packet)
|
||||||
|
|
||||||
|
logging.info("Encoded %s audio frames", audio_frame_count)
|
||||||
|
|
||||||
|
# Close containers
|
||||||
|
output_container.close()
|
||||||
|
input_container.close()
|
||||||
|
|
||||||
|
# Return as VideoFromFile using the buffer
|
||||||
|
output_buffer.seek(0)
|
||||||
|
return InputImpl.VideoFromFile(output_buffer)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up on error
|
||||||
|
if input_container is not None:
|
||||||
|
input_container.close()
|
||||||
|
if output_container is not None:
|
||||||
|
output_container.close()
|
||||||
|
raise RuntimeError(f"Failed to trim video: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def _f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file."""
|
||||||
|
if wav.dtype.is_floating_point:
|
||||||
|
return wav
|
||||||
|
elif wav.dtype == torch.int16:
|
||||||
|
return wav.float() / (2**15)
|
||||||
|
elif wav.dtype == torch.int32:
|
||||||
|
return wav.float() / (2**31)
|
||||||
|
raise ValueError(f"Unsupported wav dtype: {wav.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict:
|
||||||
|
"""
|
||||||
|
Decode any common audio container from bytes using PyAV and return
|
||||||
|
a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}.
|
||||||
|
"""
|
||||||
|
with av.open(BytesIO(audio_bytes)) as af:
|
||||||
|
if not af.streams.audio:
|
||||||
|
raise ValueError("No audio stream found in response.")
|
||||||
|
stream = af.streams.audio[0]
|
||||||
|
|
||||||
|
in_sr = int(stream.codec_context.sample_rate)
|
||||||
|
out_sr = in_sr
|
||||||
|
|
||||||
|
frames: list[torch.Tensor] = []
|
||||||
|
n_channels = stream.channels or 1
|
||||||
|
|
||||||
|
for frame in af.decode(streams=stream.index):
|
||||||
|
arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T]
|
||||||
|
buf = torch.from_numpy(arr)
|
||||||
|
if buf.ndim == 1:
|
||||||
|
buf = buf.unsqueeze(0) # [T] -> [1, T]
|
||||||
|
elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels:
|
||||||
|
buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T]
|
||||||
|
elif buf.shape[0] != n_channels:
|
||||||
|
buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T]
|
||||||
|
frames.append(buf)
|
||||||
|
|
||||||
|
if not frames:
|
||||||
|
raise ValueError("Decoded zero audio frames.")
|
||||||
|
|
||||||
|
wav = torch.cat(frames, dim=1) # [C, T]
|
||||||
|
wav = _f32_pcm(wav)
|
||||||
|
return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr}
|
||||||
|
|
||||||
|
|
||||||
|
def resize_mask_to_image(
|
||||||
|
mask: torch.Tensor,
|
||||||
|
image: torch.Tensor,
|
||||||
|
upscale_method="nearest-exact",
|
||||||
|
crop="disabled",
|
||||||
|
allow_gradient=True,
|
||||||
|
add_channel_dim=False,
|
||||||
|
):
|
||||||
|
"""Resize mask to be the same dimensions as an image, while maintaining proper format for API calls."""
|
||||||
|
_, height, width, _ = image.shape
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
mask = mask.movedim(-1, 1)
|
||||||
|
mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop)
|
||||||
|
mask = mask.movedim(1, -1)
|
||||||
|
if not add_channel_dim:
|
||||||
|
mask = mask.squeeze(-1)
|
||||||
|
if not allow_gradient:
|
||||||
|
mask = (mask > 0.5).float()
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def text_filepath_to_base64_string(filepath: str) -> str:
|
||||||
|
"""Converts a text file to a base64 string."""
|
||||||
|
with open(filepath, "rb") as f:
|
||||||
|
file_content = f.read()
|
||||||
|
return base64.b64encode(file_content).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def text_filepath_to_data_uri(filepath: str) -> str:
|
||||||
|
"""Converts a text file to a data URI."""
|
||||||
|
base64_string = text_filepath_to_base64_string(filepath)
|
||||||
|
mime_type, _ = mimetypes.guess_type(filepath)
|
||||||
|
if mime_type is None:
|
||||||
|
mime_type = "application/octet-stream"
|
||||||
|
return f"data:{mime_type};base64,{base64_string}"
|
||||||
262
comfy_api_nodes/util/download_helpers.py
Normal file
262
comfy_api_nodes/util/download_helpers.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import IO, Optional, Union
|
||||||
|
from urllib.parse import urljoin, urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import torch
|
||||||
|
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||||
|
|
||||||
|
from comfy_api.input_impl import VideoFromFile
|
||||||
|
from comfy_api.latest import IO as COMFY_IO
|
||||||
|
|
||||||
|
from . import request_logger
|
||||||
|
from ._helpers import (
|
||||||
|
default_base_url,
|
||||||
|
get_auth_header,
|
||||||
|
is_processing_interrupted,
|
||||||
|
sleep_with_interrupt,
|
||||||
|
)
|
||||||
|
from .client import _diagnose_connectivity
|
||||||
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||||
|
from .conversions import bytesio_to_image_tensor
|
||||||
|
|
||||||
|
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||||
|
|
||||||
|
|
||||||
|
async def download_url_to_bytesio(
|
||||||
|
url: str,
|
||||||
|
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
max_retries: int = 5,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
|
retry_backoff: float = 2.0,
|
||||||
|
cls: type[COMFY_IO.ComfyNode] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Stream-download a URL to `dest`.
|
||||||
|
|
||||||
|
`dest` must be one of:
|
||||||
|
- a BytesIO (rewound to 0 after write),
|
||||||
|
- a file-like object opened in binary write mode (must implement .write()),
|
||||||
|
- a filesystem path (str | pathlib.Path), which will be opened with 'wb'.
|
||||||
|
|
||||||
|
If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded
|
||||||
|
to an absolute URL and authentication headers can be applied.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors)
|
||||||
|
"""
|
||||||
|
if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"):
|
||||||
|
raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().")
|
||||||
|
|
||||||
|
attempt = 0
|
||||||
|
delay = retry_delay
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
|
if cls is None:
|
||||||
|
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||||
|
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||||
|
headers = get_auth_header(cls)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
attempt += 1
|
||||||
|
op_id = _generate_operation_id("GET", url, attempt)
|
||||||
|
timeout_cfg = aiohttp.ClientTimeout(total=timeout)
|
||||||
|
|
||||||
|
is_path_sink = isinstance(dest, (str, Path))
|
||||||
|
fhandle = None
|
||||||
|
session: Optional[aiohttp.ClientSession] = None
|
||||||
|
stop_evt: Optional[asyncio.Event] = None
|
||||||
|
monitor_task: Optional[asyncio.Task] = None
|
||||||
|
req_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url)
|
||||||
|
|
||||||
|
session = aiohttp.ClientSession(timeout=timeout_cfg)
|
||||||
|
stop_evt = asyncio.Event()
|
||||||
|
|
||||||
|
async def _monitor():
|
||||||
|
try:
|
||||||
|
while not stop_evt.is_set():
|
||||||
|
if is_processing_interrupted():
|
||||||
|
return
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
|
||||||
|
monitor_task = asyncio.create_task(_monitor())
|
||||||
|
|
||||||
|
req_task = asyncio.create_task(session.get(url, headers=headers))
|
||||||
|
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
if monitor_task in done and req_task in pending:
|
||||||
|
req_task.cancel()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await req_task
|
||||||
|
raise ProcessingInterrupted("Task cancelled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await req_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise ProcessingInterrupted("Task cancelled") from None
|
||||||
|
|
||||||
|
async with resp:
|
||||||
|
if resp.status >= 400:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
try:
|
||||||
|
body = await resp.json()
|
||||||
|
except (ContentTypeError, ValueError):
|
||||||
|
text = await resp.text()
|
||||||
|
body = text if len(text) <= 4096 else f"[text {len(text)} bytes]"
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=op_id,
|
||||||
|
request_method="GET",
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=body,
|
||||||
|
error_message=f"HTTP {resp.status}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if resp.status in _RETRY_STATUS and attempt <= max_retries:
|
||||||
|
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||||
|
delay *= retry_backoff
|
||||||
|
continue
|
||||||
|
raise Exception(f"Failed to download (HTTP {resp.status}).")
|
||||||
|
|
||||||
|
if is_path_sink:
|
||||||
|
p = Path(str(dest))
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
fhandle = open(p, "wb")
|
||||||
|
sink = fhandle
|
||||||
|
else:
|
||||||
|
sink = dest # BytesIO or file-like
|
||||||
|
|
||||||
|
written = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
chunk = b""
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise ProcessingInterrupted("Task cancelled") from None
|
||||||
|
|
||||||
|
if is_processing_interrupted():
|
||||||
|
raise ProcessingInterrupted("Task cancelled")
|
||||||
|
|
||||||
|
if not chunk:
|
||||||
|
if resp.content.at_eof():
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
sink.write(chunk)
|
||||||
|
written += len(chunk)
|
||||||
|
|
||||||
|
if isinstance(dest, BytesIO):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
dest.seek(0)
|
||||||
|
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=op_id,
|
||||||
|
request_method="GET",
|
||||||
|
request_url=url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=f"[streamed {written} bytes to dest]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise ProcessingInterrupted("Task cancelled") from None
|
||||||
|
except (ClientError, OSError) as e:
|
||||||
|
if attempt <= max_retries:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=op_id,
|
||||||
|
request_method="GET",
|
||||||
|
request_url=url,
|
||||||
|
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||||
|
)
|
||||||
|
await sleep_with_interrupt(delay, cls, None, None, None)
|
||||||
|
delay *= retry_backoff
|
||||||
|
continue
|
||||||
|
|
||||||
|
diag = await _diagnose_connectivity()
|
||||||
|
if not diag["internet_accessible"]:
|
||||||
|
raise LocalNetworkError(
|
||||||
|
"Unable to connect to the network. Please check your internet connection and try again."
|
||||||
|
) from e
|
||||||
|
raise ApiServerError("The remote service appears unreachable at this time.") from e
|
||||||
|
finally:
|
||||||
|
if stop_evt is not None:
|
||||||
|
stop_evt.set()
|
||||||
|
if monitor_task:
|
||||||
|
monitor_task.cancel()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await monitor_task
|
||||||
|
if req_task and not req_task.done():
|
||||||
|
req_task.cancel()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await req_task
|
||||||
|
if session:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await session.close()
|
||||||
|
if fhandle:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
fhandle.flush()
|
||||||
|
fhandle.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def download_url_to_image_tensor(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
timeout: float = None,
|
||||||
|
cls: type[COMFY_IO.ComfyNode] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
|
||||||
|
result = BytesIO()
|
||||||
|
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
|
||||||
|
return bytesio_to_image_tensor(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_url_to_video_output(
|
||||||
|
video_url: str,
|
||||||
|
*,
|
||||||
|
timeout: float = None,
|
||||||
|
max_retries: int = 5,
|
||||||
|
cls: type[COMFY_IO.ComfyNode] = None,
|
||||||
|
) -> VideoFromFile:
|
||||||
|
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||||
|
result = BytesIO()
|
||||||
|
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||||
|
return VideoFromFile(result)
|
||||||
|
|
||||||
|
|
||||||
|
async def download_url_as_bytesio(
|
||||||
|
url: str,
|
||||||
|
*,
|
||||||
|
timeout: float = None,
|
||||||
|
cls: type[COMFY_IO.ComfyNode] = None,
|
||||||
|
) -> BytesIO:
|
||||||
|
"""Downloads content from a URL and returns a new BytesIO (rewound to 0)."""
|
||||||
|
result = BytesIO()
|
||||||
|
await download_url_to_bytesio(url, result, timeout=timeout, cls=cls)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_operation_id(method: str, url: str, attempt: int) -> str:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_")
|
||||||
|
except Exception:
|
||||||
|
slug = "download"
|
||||||
|
return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}"
|
||||||
@ -1,11 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
import hashlib
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
@ -21,7 +21,7 @@ def get_log_directory():
|
|||||||
try:
|
try:
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
os.makedirs(log_dir, exist_ok=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error creating API log directory {log_dir}: {e}")
|
logger.error("Error creating API log directory %s: %s", log_dir, str(e))
|
||||||
# Fallback to base temp directory if sub-directory creation fails
|
# Fallback to base temp directory if sub-directory creation fails
|
||||||
return base_temp_dir
|
return base_temp_dir
|
||||||
return log_dir
|
return log_dir
|
||||||
@ -122,9 +122,9 @@ def log_request_response(
|
|||||||
try:
|
try:
|
||||||
with open(filepath, "w", encoding="utf-8") as f:
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
f.write("\n".join(log_content))
|
f.write("\n".join(log_content))
|
||||||
logger.debug(f"API log saved to: {filepath}")
|
logger.debug("API log saved to: %s", filepath)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error writing API log to {filepath}: {e}")
|
logger.error("Error writing API log to %s: %s", filepath, str(e))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
338
comfy_api_nodes/util/upload_helpers.py
Normal file
338
comfy_api_nodes/util/upload_helpers.py
Normal file
@ -0,0 +1,338 @@
|
|||||||
|
import asyncio
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Optional, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
|
import torch
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from comfy_api.latest import IO, Input
|
||||||
|
from comfy_api.util import VideoCodec, VideoContainer
|
||||||
|
|
||||||
|
from . import request_logger
|
||||||
|
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||||
|
from .client import (
|
||||||
|
ApiEndpoint,
|
||||||
|
_diagnose_connectivity,
|
||||||
|
_display_time_progress,
|
||||||
|
sync_op,
|
||||||
|
)
|
||||||
|
from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted
|
||||||
|
from .conversions import (
|
||||||
|
audio_ndarray_to_bytesio,
|
||||||
|
audio_tensor_to_contiguous_ndarray,
|
||||||
|
tensor_to_bytesio,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadRequest(BaseModel):
|
||||||
|
file_name: str = Field(..., description="Filename to upload")
|
||||||
|
content_type: Optional[str] = Field(
|
||||||
|
None,
|
||||||
|
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UploadResponse(BaseModel):
|
||||||
|
download_url: str = Field(..., description="URL to GET uploaded file")
|
||||||
|
upload_url: str = Field(..., description="URL to PUT file to upload")
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_images_to_comfyapi(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
image: torch.Tensor,
|
||||||
|
*,
|
||||||
|
max_images: int = 8,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
wait_label: Optional[str] = "Uploading",
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Uploads images to ComfyUI API and returns download URLs.
|
||||||
|
To upload multiple images, stack them in the batch dimension first.
|
||||||
|
"""
|
||||||
|
# if batch, try to upload each file if max_images is greater than 0
|
||||||
|
download_urls: list[str] = []
|
||||||
|
is_batch = len(image.shape) > 3
|
||||||
|
batch_len = image.shape[0] if is_batch else 1
|
||||||
|
|
||||||
|
for idx in range(min(batch_len, max_images)):
|
||||||
|
tensor = image[idx] if is_batch else image
|
||||||
|
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||||
|
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
|
||||||
|
download_urls.append(url)
|
||||||
|
return download_urls
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_audio_to_comfyapi(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
audio: Input.Audio,
|
||||||
|
*,
|
||||||
|
container_format: str = "mp4",
|
||||||
|
codec_name: str = "aac",
|
||||||
|
mime_type: str = "audio/mp4",
|
||||||
|
filename: str = "uploaded_audio.mp4",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Uploads a single audio input to ComfyUI API and returns its download URL.
|
||||||
|
Encodes the raw waveform into the specified format before uploading.
|
||||||
|
"""
|
||||||
|
sample_rate: int = audio["sample_rate"]
|
||||||
|
waveform: torch.Tensor = audio["waveform"]
|
||||||
|
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
|
||||||
|
audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name)
|
||||||
|
return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_video_to_comfyapi(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
video: Input.Video,
|
||||||
|
*,
|
||||||
|
container: VideoContainer = VideoContainer.MP4,
|
||||||
|
codec: VideoCodec = VideoCodec.H264,
|
||||||
|
max_duration: Optional[int] = None,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Uploads a single video to ComfyUI API and returns its download URL.
|
||||||
|
Uses the specified container and codec for saving the video before upload.
|
||||||
|
"""
|
||||||
|
if max_duration is not None:
|
||||||
|
try:
|
||||||
|
actual_duration = video.get_duration()
|
||||||
|
if actual_duration > max_duration:
|
||||||
|
raise ValueError(
|
||||||
|
f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error getting video duration: %s", str(e))
|
||||||
|
raise ValueError(f"Could not verify video duration from source: {e}") from e
|
||||||
|
|
||||||
|
upload_mime_type = f"video/{container.value.lower()}"
|
||||||
|
filename = f"uploaded_video.{container.value.lower()}"
|
||||||
|
|
||||||
|
# Convert VideoInput to BytesIO using specified container/codec
|
||||||
|
video_bytes_io = BytesIO()
|
||||||
|
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||||
|
video_bytes_io.seek(0)
|
||||||
|
|
||||||
|
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_file_to_comfyapi(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
file_bytes_io: BytesIO,
|
||||||
|
filename: str,
|
||||||
|
upload_mime_type: Optional[str],
|
||||||
|
wait_label: Optional[str] = "Uploading",
|
||||||
|
) -> str:
|
||||||
|
"""Uploads a single file to ComfyUI API and returns its download URL."""
|
||||||
|
if upload_mime_type is None:
|
||||||
|
request_object = UploadRequest(file_name=filename)
|
||||||
|
else:
|
||||||
|
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
|
||||||
|
create_resp = await sync_op(
|
||||||
|
cls,
|
||||||
|
endpoint=ApiEndpoint(path="/customers/storage", method="POST"),
|
||||||
|
data=request_object,
|
||||||
|
response_model=UploadResponse,
|
||||||
|
final_label_on_success=None,
|
||||||
|
monitor_progress=False,
|
||||||
|
)
|
||||||
|
await upload_file(
|
||||||
|
cls,
|
||||||
|
create_resp.upload_url,
|
||||||
|
file_bytes_io,
|
||||||
|
content_type=upload_mime_type,
|
||||||
|
wait_label=wait_label,
|
||||||
|
)
|
||||||
|
return create_resp.download_url
|
||||||
|
|
||||||
|
|
||||||
|
async def upload_file(
|
||||||
|
cls: type[IO.ComfyNode],
|
||||||
|
upload_url: str,
|
||||||
|
file: Union[BytesIO, str],
|
||||||
|
*,
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
max_retries: int = 3,
|
||||||
|
retry_delay: float = 1.0,
|
||||||
|
retry_backoff: float = 2.0,
|
||||||
|
wait_label: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cls: Node class (provides auth context + UI progress hooks).
|
||||||
|
upload_url: Pre-signed PUT URL.
|
||||||
|
file: BytesIO or path string.
|
||||||
|
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
|
||||||
|
max_retries: Maximum retry attempts.
|
||||||
|
retry_delay: Initial delay in seconds.
|
||||||
|
retry_backoff: Exponential backoff factor.
|
||||||
|
wait_label: Progress label shown in Comfy UI.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
|
||||||
|
"""
|
||||||
|
if isinstance(file, BytesIO):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
file.seek(0)
|
||||||
|
data = file.read()
|
||||||
|
elif isinstance(file, str):
|
||||||
|
with open(file, "rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
else:
|
||||||
|
raise ValueError("file must be a BytesIO or a filesystem path string")
|
||||||
|
|
||||||
|
headers: dict[str, str] = {}
|
||||||
|
skip_auto_headers: set[str] = set()
|
||||||
|
if content_type:
|
||||||
|
headers["Content-Type"] = content_type
|
||||||
|
else:
|
||||||
|
skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request
|
||||||
|
|
||||||
|
attempt = 0
|
||||||
|
delay = retry_delay
|
||||||
|
start_ts = time.monotonic()
|
||||||
|
op_uuid = uuid.uuid4().hex[:8]
|
||||||
|
while True:
|
||||||
|
attempt += 1
|
||||||
|
operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid)
|
||||||
|
timeout = aiohttp.ClientTimeout(total=None)
|
||||||
|
stop_evt = asyncio.Event()
|
||||||
|
|
||||||
|
async def _monitor():
|
||||||
|
try:
|
||||||
|
while not stop_evt.is_set():
|
||||||
|
if is_processing_interrupted():
|
||||||
|
return
|
||||||
|
if wait_label:
|
||||||
|
_display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None)
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
return
|
||||||
|
|
||||||
|
monitor_task = asyncio.create_task(_monitor())
|
||||||
|
sess: Optional[aiohttp.ClientSession] = None
|
||||||
|
try:
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method="PUT",
|
||||||
|
request_url=upload_url,
|
||||||
|
request_headers=headers or None,
|
||||||
|
request_params=None,
|
||||||
|
request_data=f"[File data {len(data)} bytes]",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug("[DEBUG] upload request logging failed: %s", e)
|
||||||
|
|
||||||
|
sess = aiohttp.ClientSession(timeout=timeout)
|
||||||
|
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
|
||||||
|
req_task = asyncio.create_task(req)
|
||||||
|
|
||||||
|
done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
|
||||||
|
if monitor_task in done and req_task in pending:
|
||||||
|
req_task.cancel()
|
||||||
|
raise ProcessingInterrupted("Upload cancelled")
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = await req_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise ProcessingInterrupted("Upload cancelled") from None
|
||||||
|
|
||||||
|
async with resp:
|
||||||
|
if resp.status >= 400:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
try:
|
||||||
|
body = await resp.json()
|
||||||
|
except Exception:
|
||||||
|
body = await resp.text()
|
||||||
|
msg = f"Upload failed with status {resp.status}"
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method="PUT",
|
||||||
|
request_url=upload_url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content=body,
|
||||||
|
error_message=msg,
|
||||||
|
)
|
||||||
|
if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries:
|
||||||
|
await sleep_with_interrupt(
|
||||||
|
delay,
|
||||||
|
cls,
|
||||||
|
wait_label,
|
||||||
|
start_ts,
|
||||||
|
None,
|
||||||
|
display_callback=_display_time_progress if wait_label else None,
|
||||||
|
)
|
||||||
|
delay *= retry_backoff
|
||||||
|
continue
|
||||||
|
raise Exception(f"Failed to upload (HTTP {resp.status}).")
|
||||||
|
try:
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method="PUT",
|
||||||
|
request_url=upload_url,
|
||||||
|
response_status_code=resp.status,
|
||||||
|
response_headers=dict(resp.headers),
|
||||||
|
response_content="File uploaded successfully.",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug("[DEBUG] upload response logging failed: %s", e)
|
||||||
|
return
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise ProcessingInterrupted("Task cancelled") from None
|
||||||
|
except (aiohttp.ClientError, OSError) as e:
|
||||||
|
if attempt <= max_retries:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
request_logger.log_request_response(
|
||||||
|
operation_id=operation_id,
|
||||||
|
request_method="PUT",
|
||||||
|
request_url=upload_url,
|
||||||
|
request_headers=headers or None,
|
||||||
|
request_data=f"[File data {len(data)} bytes]",
|
||||||
|
error_message=f"{type(e).__name__}: {str(e)} (will retry)",
|
||||||
|
)
|
||||||
|
await sleep_with_interrupt(
|
||||||
|
delay,
|
||||||
|
cls,
|
||||||
|
wait_label,
|
||||||
|
start_ts,
|
||||||
|
None,
|
||||||
|
display_callback=_display_time_progress if wait_label else None,
|
||||||
|
)
|
||||||
|
delay *= retry_backoff
|
||||||
|
continue
|
||||||
|
|
||||||
|
diag = await _diagnose_connectivity()
|
||||||
|
if not diag["internet_accessible"]:
|
||||||
|
raise LocalNetworkError(
|
||||||
|
"Unable to connect to the network. Please check your internet connection and try again."
|
||||||
|
) from e
|
||||||
|
raise ApiServerError("The API service appears unreachable at this time.") from e
|
||||||
|
finally:
|
||||||
|
stop_evt.set()
|
||||||
|
if monitor_task:
|
||||||
|
monitor_task.cancel()
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await monitor_task
|
||||||
|
if sess:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
await sess.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str:
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_")
|
||||||
|
except Exception:
|
||||||
|
slug = "upload"
|
||||||
|
return f"{method}_{slug}_{op_uuid}_try{attempt}"
|
||||||
@ -2,6 +2,8 @@ import logging
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from comfy_api.input.video_types import VideoInput
|
||||||
from comfy_api.latest import Input
|
from comfy_api.latest import Input
|
||||||
|
|
||||||
|
|
||||||
@ -28,76 +30,69 @@ def validate_image_dimensions(
|
|||||||
if max_width is not None and width > max_width:
|
if max_width is not None and width > max_width:
|
||||||
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
|
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
|
||||||
if min_height is not None and height < min_height:
|
if min_height is not None and height < min_height:
|
||||||
raise ValueError(
|
raise ValueError(f"Image height must be at least {min_height}px, got {height}px")
|
||||||
f"Image height must be at least {min_height}px, got {height}px"
|
|
||||||
)
|
|
||||||
if max_height is not None and height > max_height:
|
if max_height is not None and height > max_height:
|
||||||
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
|
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
|
||||||
|
|
||||||
|
|
||||||
def validate_image_aspect_ratio(
|
def validate_image_aspect_ratio(
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
min_aspect_ratio: Optional[float] = None,
|
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||||
max_aspect_ratio: Optional[float] = None,
|
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||||
):
|
|
||||||
width, height = get_image_dimensions(image)
|
|
||||||
aspect_ratio = width / height
|
|
||||||
|
|
||||||
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
|
|
||||||
raise ValueError(
|
|
||||||
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
|
|
||||||
)
|
|
||||||
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
|
|
||||||
raise ValueError(
|
|
||||||
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_image_aspect_ratio_range(
|
|
||||||
image: torch.Tensor,
|
|
||||||
min_ratio: tuple[float, float], # e.g. (1, 4)
|
|
||||||
max_ratio: tuple[float, float], # e.g. (4, 1)
|
|
||||||
*,
|
*,
|
||||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||||
) -> float:
|
) -> float:
|
||||||
a1, b1 = min_ratio
|
"""Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked."""
|
||||||
a2, b2 = max_ratio
|
|
||||||
if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0:
|
|
||||||
raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).")
|
|
||||||
lo, hi = (a1 / b1), (a2 / b2)
|
|
||||||
if lo > hi:
|
|
||||||
lo, hi = hi, lo
|
|
||||||
a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text
|
|
||||||
w, h = get_image_dimensions(image)
|
w, h = get_image_dimensions(image)
|
||||||
if w <= 0 or h <= 0:
|
if w <= 0 or h <= 0:
|
||||||
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
raise ValueError(f"Invalid image dimensions: {w}x{h}")
|
||||||
ar = w / h
|
ar = w / h
|
||||||
ok = (lo < ar < hi) if strict else (lo <= ar <= hi)
|
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||||
if not ok:
|
|
||||||
op = "<" if strict else "≤"
|
|
||||||
raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}")
|
|
||||||
return ar
|
return ar
|
||||||
|
|
||||||
|
|
||||||
def validate_aspect_ratio_closeness(
|
def validate_images_aspect_ratio_closeness(
|
||||||
start_img,
|
first_image: torch.Tensor,
|
||||||
end_img,
|
second_image: torch.Tensor,
|
||||||
min_rel: float,
|
min_rel: float, # e.g. 0.8
|
||||||
max_rel: float,
|
max_rel: float, # e.g. 1.25
|
||||||
*,
|
*,
|
||||||
strict: bool = False, # True => exclusive, False => inclusive
|
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||||
) -> None:
|
) -> float:
|
||||||
w1, h1 = get_image_dimensions(start_img)
|
"""
|
||||||
w2, h2 = get_image_dimensions(end_img)
|
Validates that the two images' aspect ratios are 'close'.
|
||||||
|
The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1).
|
||||||
|
We require C <= limit, where limit = max(max_rel, 1.0 / min_rel).
|
||||||
|
|
||||||
|
Returns the computed closeness factor C.
|
||||||
|
"""
|
||||||
|
w1, h1 = get_image_dimensions(first_image)
|
||||||
|
w2, h2 = get_image_dimensions(second_image)
|
||||||
if min(w1, h1, w2, h2) <= 0:
|
if min(w1, h1, w2, h2) <= 0:
|
||||||
raise ValueError("Invalid image dimensions")
|
raise ValueError("Invalid image dimensions")
|
||||||
ar1 = w1 / h1
|
ar1 = w1 / h1
|
||||||
ar2 = w2 / h2
|
ar2 = w2 / h2
|
||||||
# Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1)
|
|
||||||
closeness = max(ar1, ar2) / min(ar1, ar2)
|
closeness = max(ar1, ar2) / min(ar1, ar2)
|
||||||
limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25
|
limit = max(max_rel, 1.0 / min_rel)
|
||||||
if (closeness >= limit) if strict else (closeness > limit):
|
if (closeness >= limit) if strict else (closeness > limit):
|
||||||
raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.")
|
raise ValueError(
|
||||||
|
f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, "
|
||||||
|
f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})."
|
||||||
|
)
|
||||||
|
return closeness
|
||||||
|
|
||||||
|
|
||||||
|
def validate_aspect_ratio_string(
|
||||||
|
aspect_ratio: str,
|
||||||
|
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||||
|
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||||
|
*,
|
||||||
|
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||||
|
) -> float:
|
||||||
|
"""Parses 'X:Y' and validates it against optional bounds. Returns the numeric ratio."""
|
||||||
|
ar = _parse_aspect_ratio_string(aspect_ratio)
|
||||||
|
_assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict)
|
||||||
|
return ar
|
||||||
|
|
||||||
|
|
||||||
def validate_video_dimensions(
|
def validate_video_dimensions(
|
||||||
@ -118,9 +113,7 @@ def validate_video_dimensions(
|
|||||||
if max_width is not None and width > max_width:
|
if max_width is not None and width > max_width:
|
||||||
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
|
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
|
||||||
if min_height is not None and height < min_height:
|
if min_height is not None and height < min_height:
|
||||||
raise ValueError(
|
raise ValueError(f"Video height must be at least {min_height}px, got {height}px")
|
||||||
f"Video height must be at least {min_height}px, got {height}px"
|
|
||||||
)
|
|
||||||
if max_height is not None and height > max_height:
|
if max_height is not None and height > max_height:
|
||||||
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
|
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
|
||||||
|
|
||||||
@ -138,13 +131,9 @@ def validate_video_duration(
|
|||||||
|
|
||||||
epsilon = 0.0001
|
epsilon = 0.0001
|
||||||
if min_duration is not None and min_duration - epsilon > duration:
|
if min_duration is not None and min_duration - epsilon > duration:
|
||||||
raise ValueError(
|
raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s")
|
||||||
f"Video duration must be at least {min_duration}s, got {duration}s"
|
|
||||||
)
|
|
||||||
if max_duration is not None and duration > max_duration + epsilon:
|
if max_duration is not None and duration > max_duration + epsilon:
|
||||||
raise ValueError(
|
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||||
f"Video duration must be at most {max_duration}s, got {duration}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_number_of_images(images):
|
def get_number_of_images(images):
|
||||||
@ -165,3 +154,77 @@ def validate_audio_duration(
|
|||||||
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s")
|
||||||
if max_duration is not None and dur - eps > max_duration:
|
if max_duration is not None and dur - eps > max_duration:
|
||||||
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_string(
|
||||||
|
string: str,
|
||||||
|
strip_whitespace=True,
|
||||||
|
field_name="prompt",
|
||||||
|
min_length=None,
|
||||||
|
max_length=None,
|
||||||
|
):
|
||||||
|
if string is None:
|
||||||
|
raise Exception(f"Field '{field_name}' cannot be empty.")
|
||||||
|
if strip_whitespace:
|
||||||
|
string = string.strip()
|
||||||
|
if min_length and len(string) < min_length:
|
||||||
|
raise Exception(
|
||||||
|
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
|
||||||
|
)
|
||||||
|
if max_length and len(string) > max_length:
|
||||||
|
raise Exception(
|
||||||
|
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||||
|
"""Validates video container format is MP4."""
|
||||||
|
container_format = video.get_container_format()
|
||||||
|
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||||
|
raise ValueError(f"Only MP4 container format supported. Got: {container_format}")
|
||||||
|
|
||||||
|
|
||||||
|
def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||||||
|
a, b = r
|
||||||
|
if a <= 0 or b <= 0:
|
||||||
|
raise ValueError(f"Ratios must be positive, got {a}:{b}.")
|
||||||
|
return a / b
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_ratio_bounds(
|
||||||
|
ar: float,
|
||||||
|
*,
|
||||||
|
min_ratio: Optional[tuple[float, float]] = None,
|
||||||
|
max_ratio: Optional[tuple[float, float]] = None,
|
||||||
|
strict: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||||
|
lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None
|
||||||
|
hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None
|
||||||
|
|
||||||
|
if lo is not None and hi is not None and lo > hi:
|
||||||
|
lo, hi = hi, lo # normalize order if caller swapped them
|
||||||
|
|
||||||
|
if lo is not None:
|
||||||
|
if (ar <= lo) if strict else (ar < lo):
|
||||||
|
op = "<" if strict else "≤"
|
||||||
|
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.")
|
||||||
|
if hi is not None:
|
||||||
|
if (ar >= hi) if strict else (ar > hi):
|
||||||
|
op = "<" if strict else "≤"
|
||||||
|
raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.")
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_aspect_ratio_string(ar_str: str) -> float:
|
||||||
|
"""Parse 'X:Y' with integer parts into a positive float ratio X/Y."""
|
||||||
|
parts = ar_str.split(":")
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.")
|
||||||
|
try:
|
||||||
|
a = int(parts[0].strip())
|
||||||
|
b = int(parts[1].strip())
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc
|
||||||
|
if a <= 0 or b <= 0:
|
||||||
|
raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.")
|
||||||
|
return a / b
|
||||||
|
|||||||
@ -1,4 +1,9 @@
|
|||||||
|
import bisect
|
||||||
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
|
import psutil
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
from typing import Sequence, Mapping, Dict
|
from typing import Sequence, Mapping, Dict
|
||||||
from comfy_execution.graph import DynamicPrompt
|
from comfy_execution.graph import DynamicPrompt
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@ -48,7 +53,7 @@ class Unhashable:
|
|||||||
def to_hashable(obj):
|
def to_hashable(obj):
|
||||||
# So that we don't infinitely recurse since frozenset and tuples
|
# So that we don't infinitely recurse since frozenset and tuples
|
||||||
# are Sequences.
|
# are Sequences.
|
||||||
if isinstance(obj, (int, float, str, bool, type(None))):
|
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||||
return obj
|
return obj
|
||||||
elif isinstance(obj, Mapping):
|
elif isinstance(obj, Mapping):
|
||||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||||
@ -188,6 +193,9 @@ class BasicCache:
|
|||||||
self._clean_cache()
|
self._clean_cache()
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
|
def poll(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
def _set_immediate(self, node_id, value):
|
def _set_immediate(self, node_id, value):
|
||||||
assert self.initialized
|
assert self.initialized
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
@ -265,6 +273,29 @@ class HierarchicalCache(BasicCache):
|
|||||||
assert cache is not None
|
assert cache is not None
|
||||||
return await cache._ensure_subcache(node_id, children_ids)
|
return await cache._ensure_subcache(node_id, children_ids)
|
||||||
|
|
||||||
|
class NullCache:
|
||||||
|
|
||||||
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def all_node_ids(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def clean_unused(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def poll(self, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get(self, node_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, node_id, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
|
return self
|
||||||
|
|
||||||
class LRUCache(BasicCache):
|
class LRUCache(BasicCache):
|
||||||
def __init__(self, key_class, max_size=100):
|
def __init__(self, key_class, max_size=100):
|
||||||
super().__init__(key_class)
|
super().__init__(key_class)
|
||||||
@ -318,155 +349,75 @@ class LRUCache(BasicCache):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
class DependencyAwareCache(BasicCache):
|
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
|
||||||
"""
|
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
|
||||||
A cache implementation that tracks dependencies between nodes and manages
|
|
||||||
their execution and caching accordingly. It extends the BasicCache class.
|
RAM_CACHE_HYSTERESIS = 1.1
|
||||||
Nodes are removed from this cache once all of their descendants have been
|
|
||||||
executed.
|
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
|
||||||
"""
|
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
|
||||||
|
|
||||||
|
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
|
||||||
|
|
||||||
|
#Exponential bias towards evicting older workflows so garbage will be taken out
|
||||||
|
#in constantly changing setups.
|
||||||
|
|
||||||
|
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||||
|
|
||||||
|
class RAMPressureCache(LRUCache):
|
||||||
|
|
||||||
def __init__(self, key_class):
|
def __init__(self, key_class):
|
||||||
"""
|
super().__init__(key_class, 0)
|
||||||
Initialize the DependencyAwareCache.
|
self.timestamps = {}
|
||||||
|
|
||||||
Args:
|
|
||||||
key_class: The class used for generating cache keys.
|
|
||||||
"""
|
|
||||||
super().__init__(key_class)
|
|
||||||
self.descendants = {} # Maps node_id -> set of descendant node_ids
|
|
||||||
self.ancestors = {} # Maps node_id -> set of ancestor node_ids
|
|
||||||
self.executed_nodes = set() # Tracks nodes that have been executed
|
|
||||||
|
|
||||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
|
||||||
"""
|
|
||||||
Clear the entire cache and rebuild the dependency graph.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dynprompt: The dynamic prompt object containing node information.
|
|
||||||
node_ids: List of node IDs to initialize the cache for.
|
|
||||||
is_changed_cache: Flag indicating if the cache has changed.
|
|
||||||
"""
|
|
||||||
# Clear all existing cache data
|
|
||||||
self.cache.clear()
|
|
||||||
self.subcaches.clear()
|
|
||||||
self.descendants.clear()
|
|
||||||
self.ancestors.clear()
|
|
||||||
self.executed_nodes.clear()
|
|
||||||
|
|
||||||
# Call the parent method to initialize the cache with the new prompt
|
|
||||||
await super().set_prompt(dynprompt, node_ids, is_changed_cache)
|
|
||||||
|
|
||||||
# Rebuild the dependency graph
|
|
||||||
self._build_dependency_graph(dynprompt, node_ids)
|
|
||||||
|
|
||||||
def _build_dependency_graph(self, dynprompt, node_ids):
|
|
||||||
"""
|
|
||||||
Build the dependency graph for all nodes.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dynprompt: The dynamic prompt object containing node information.
|
|
||||||
node_ids: List of node IDs to build the graph for.
|
|
||||||
"""
|
|
||||||
self.descendants.clear()
|
|
||||||
self.ancestors.clear()
|
|
||||||
for node_id in node_ids:
|
|
||||||
self.descendants[node_id] = set()
|
|
||||||
self.ancestors[node_id] = set()
|
|
||||||
|
|
||||||
for node_id in node_ids:
|
|
||||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
|
||||||
for input_data in inputs.values():
|
|
||||||
if is_link(input_data): # Check if the input is a link to another node
|
|
||||||
ancestor_id = input_data[0]
|
|
||||||
self.descendants[ancestor_id].add(node_id)
|
|
||||||
self.ancestors[node_id].add(ancestor_id)
|
|
||||||
|
|
||||||
def set(self, node_id, value):
|
|
||||||
"""
|
|
||||||
Mark a node as executed and store its value in the cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_id: The ID of the node to store.
|
|
||||||
value: The value to store for the node.
|
|
||||||
"""
|
|
||||||
self._set_immediate(node_id, value)
|
|
||||||
self.executed_nodes.add(node_id)
|
|
||||||
self._cleanup_ancestors(node_id)
|
|
||||||
|
|
||||||
def get(self, node_id):
|
|
||||||
"""
|
|
||||||
Retrieve the cached value for a node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_id: The ID of the node to retrieve.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The cached value for the node.
|
|
||||||
"""
|
|
||||||
return self._get_immediate(node_id)
|
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
|
||||||
"""
|
|
||||||
Ensure a subcache exists for a node and update dependencies.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_id: The ID of the parent node.
|
|
||||||
children_ids: List of child node IDs to associate with the parent node.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The subcache object for the node.
|
|
||||||
"""
|
|
||||||
subcache = await super()._ensure_subcache(node_id, children_ids)
|
|
||||||
for child_id in children_ids:
|
|
||||||
self.descendants[node_id].add(child_id)
|
|
||||||
self.ancestors[child_id].add(node_id)
|
|
||||||
return subcache
|
|
||||||
|
|
||||||
def _cleanup_ancestors(self, node_id):
|
|
||||||
"""
|
|
||||||
Check if ancestors of a node can be removed from the cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_id: The ID of the node whose ancestors are to be checked.
|
|
||||||
"""
|
|
||||||
for ancestor_id in self.ancestors.get(node_id, []):
|
|
||||||
if ancestor_id in self.executed_nodes:
|
|
||||||
# Remove ancestor if all its descendants have been executed
|
|
||||||
if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
|
|
||||||
self._remove_node(ancestor_id)
|
|
||||||
|
|
||||||
def _remove_node(self, node_id):
|
|
||||||
"""
|
|
||||||
Remove a node from the cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_id: The ID of the node to remove.
|
|
||||||
"""
|
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
|
||||||
if cache_key in self.cache:
|
|
||||||
del self.cache[cache_key]
|
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
|
||||||
if subcache_key in self.subcaches:
|
|
||||||
del self.subcaches[subcache_key]
|
|
||||||
|
|
||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
"""
|
self._clean_subcaches()
|
||||||
Clean up unused nodes. This is a no-op for this cache implementation.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def recursive_debug_dump(self):
|
def set(self, node_id, value):
|
||||||
"""
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
Dump the cache and dependency graph for debugging.
|
super().set(node_id, value)
|
||||||
|
|
||||||
Returns:
|
def get(self, node_id):
|
||||||
A list containing the cache state and dependency graph.
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
"""
|
return super().get(node_id)
|
||||||
result = super().recursive_debug_dump()
|
|
||||||
result.append({
|
def poll(self, ram_headroom):
|
||||||
"descendants": self.descendants,
|
def _ram_gb():
|
||||||
"ancestors": self.ancestors,
|
return psutil.virtual_memory().available / (1024**3)
|
||||||
"executed_nodes": list(self.executed_nodes),
|
|
||||||
})
|
if _ram_gb() > ram_headroom:
|
||||||
return result
|
return
|
||||||
|
gc.collect()
|
||||||
|
if _ram_gb() > ram_headroom:
|
||||||
|
return
|
||||||
|
|
||||||
|
clean_list = []
|
||||||
|
|
||||||
|
for key, (outputs, _), in self.cache.items():
|
||||||
|
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
|
||||||
|
|
||||||
|
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
|
||||||
|
def scan_list_for_ram_usage(outputs):
|
||||||
|
nonlocal ram_usage
|
||||||
|
if outputs is None:
|
||||||
|
return
|
||||||
|
for output in outputs:
|
||||||
|
if isinstance(output, list):
|
||||||
|
scan_list_for_ram_usage(output)
|
||||||
|
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
|
||||||
|
#score Tensors at a 50% discount for RAM usage as they are likely to
|
||||||
|
#be high value intermediates
|
||||||
|
ram_usage += (output.numel() * output.element_size()) * 0.5
|
||||||
|
elif hasattr(output, "get_ram_usage"):
|
||||||
|
ram_usage += output.get_ram_usage()
|
||||||
|
scan_list_for_ram_usage(outputs)
|
||||||
|
|
||||||
|
oom_score *= ram_usage
|
||||||
|
#In the case where we have no information on the node ram usage at all,
|
||||||
|
#break OOM score ties on the last touch timestamp (pure LRU)
|
||||||
|
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
|
||||||
|
|
||||||
|
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
|
||||||
|
_, _, key = clean_list.pop()
|
||||||
|
del self.cache[key]
|
||||||
|
gc.collect()
|
||||||
|
|||||||
@ -153,8 +153,9 @@ class TopologicalSort:
|
|||||||
continue
|
continue
|
||||||
_, _, input_info = self.get_input_info(unique_id, input_name)
|
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
if (include_lazy or not is_lazy):
|
||||||
node_ids.append(from_node_id)
|
if not self.is_cached(from_node_id):
|
||||||
|
node_ids.append(from_node_id)
|
||||||
links.append((from_node_id, from_socket, unique_id))
|
links.append((from_node_id, from_socket, unique_id))
|
||||||
|
|
||||||
for link in links:
|
for link in links:
|
||||||
@ -194,10 +195,40 @@ class ExecutionList(TopologicalSort):
|
|||||||
super().__init__(dynprompt)
|
super().__init__(dynprompt)
|
||||||
self.output_cache = output_cache
|
self.output_cache = output_cache
|
||||||
self.staged_node_id = None
|
self.staged_node_id = None
|
||||||
|
self.execution_cache = {}
|
||||||
|
self.execution_cache_listeners = {}
|
||||||
|
|
||||||
def is_cached(self, node_id):
|
def is_cached(self, node_id):
|
||||||
return self.output_cache.get(node_id) is not None
|
return self.output_cache.get(node_id) is not None
|
||||||
|
|
||||||
|
def cache_link(self, from_node_id, to_node_id):
|
||||||
|
if not to_node_id in self.execution_cache:
|
||||||
|
self.execution_cache[to_node_id] = {}
|
||||||
|
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||||
|
if not from_node_id in self.execution_cache_listeners:
|
||||||
|
self.execution_cache_listeners[from_node_id] = set()
|
||||||
|
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||||
|
|
||||||
|
def get_cache(self, from_node_id, to_node_id):
|
||||||
|
if not to_node_id in self.execution_cache:
|
||||||
|
return None
|
||||||
|
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
#Write back to the main cache on touch.
|
||||||
|
self.output_cache.set(from_node_id, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def cache_update(self, node_id, value):
|
||||||
|
if node_id in self.execution_cache_listeners:
|
||||||
|
for to_node_id in self.execution_cache_listeners[node_id]:
|
||||||
|
if to_node_id in self.execution_cache:
|
||||||
|
self.execution_cache[to_node_id][node_id] = value
|
||||||
|
|
||||||
|
def add_strong_link(self, from_node_id, from_socket, to_node_id):
|
||||||
|
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||||
|
self.cache_link(from_node_id, to_node_id)
|
||||||
|
|
||||||
async def stage_node_execution(self):
|
async def stage_node_execution(self):
|
||||||
assert self.staged_node_id is None
|
assert self.staged_node_id is None
|
||||||
if self.is_empty():
|
if self.is_empty():
|
||||||
@ -277,6 +308,8 @@ class ExecutionList(TopologicalSort):
|
|||||||
def complete_node_execution(self):
|
def complete_node_execution(self):
|
||||||
node_id = self.staged_node_id
|
node_id = self.staged_node_id
|
||||||
self.pop_node(node_id)
|
self.pop_node(node_id)
|
||||||
|
self.execution_cache.pop(node_id, None)
|
||||||
|
self.execution_cache_listeners.pop(node_id, None)
|
||||||
self.staged_node_id = None
|
self.staged_node_id = None
|
||||||
|
|
||||||
def get_nodes_in_cycle(self):
|
def get_nodes_in_cycle(self):
|
||||||
|
|||||||
@ -142,9 +142,10 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
|||||||
for key, value in metadata.items():
|
for key, value in metadata.items():
|
||||||
output_container.metadata[key] = value
|
output_container.metadata[key] = value
|
||||||
|
|
||||||
|
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
|
||||||
# Set up the output stream with appropriate properties
|
# Set up the output stream with appropriate properties
|
||||||
if format == "opus":
|
if format == "opus":
|
||||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
|
||||||
if quality == "64k":
|
if quality == "64k":
|
||||||
out_stream.bit_rate = 64000
|
out_stream.bit_rate = 64000
|
||||||
elif quality == "96k":
|
elif quality == "96k":
|
||||||
@ -156,7 +157,7 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
|||||||
elif quality == "320k":
|
elif quality == "320k":
|
||||||
out_stream.bit_rate = 320000
|
out_stream.bit_rate = 320000
|
||||||
elif format == "mp3":
|
elif format == "mp3":
|
||||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
||||||
if quality == "V0":
|
if quality == "V0":
|
||||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||||
out_stream.codec_context.qscale = 1
|
out_stream.codec_context.qscale = 1
|
||||||
@ -165,9 +166,9 @@ def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=Non
|
|||||||
elif quality == "320k":
|
elif quality == "320k":
|
||||||
out_stream.bit_rate = 320000
|
out_stream.bit_rate = 320000
|
||||||
else: #format == "flac":
|
else: #format == "flac":
|
||||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
|
||||||
|
|
||||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo')
|
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
|
||||||
frame.sample_rate = sample_rate
|
frame.sample_rate = sample_rate
|
||||||
frame.pts = 0
|
frame.pts = 0
|
||||||
output_container.mux(out_stream.encode(frame))
|
output_container.mux(out_stream.encode(frame))
|
||||||
|
|||||||
@ -1,6 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def resize_mask(mask, shape):
|
def resize_mask(mask, shape):
|
||||||
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
|
||||||
@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
|
|||||||
return out_image, out_alpha
|
return out_image, out_alpha
|
||||||
|
|
||||||
|
|
||||||
class PorterDuffImageComposite:
|
class PorterDuffImageComposite(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="PorterDuffImageComposite",
|
||||||
"source": ("IMAGE",),
|
display_name="Porter-Duff Image Composite",
|
||||||
"source_alpha": ("MASK",),
|
category="mask/compositing",
|
||||||
"destination": ("IMAGE",),
|
inputs=[
|
||||||
"destination_alpha": ("MASK",),
|
io.Image.Input("source"),
|
||||||
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
|
io.Mask.Input("source_alpha"),
|
||||||
},
|
io.Image.Input("destination"),
|
||||||
}
|
io.Mask.Input("destination_alpha"),
|
||||||
|
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
io.Mask.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
@classmethod
|
||||||
FUNCTION = "composite"
|
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
|
||||||
CATEGORY = "mask/compositing"
|
|
||||||
|
|
||||||
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
|
|
||||||
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
|
||||||
out_images = []
|
out_images = []
|
||||||
out_alphas = []
|
out_alphas = []
|
||||||
@ -150,45 +157,48 @@ class PorterDuffImageComposite:
|
|||||||
out_images.append(out_image)
|
out_images.append(out_image)
|
||||||
out_alphas.append(out_alpha.squeeze(2))
|
out_alphas.append(out_alpha.squeeze(2))
|
||||||
|
|
||||||
result = (torch.stack(out_images), torch.stack(out_alphas))
|
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class SplitImageWithAlpha:
|
class SplitImageWithAlpha(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="SplitImageWithAlpha",
|
||||||
"image": ("IMAGE",),
|
display_name="Split Image with Alpha",
|
||||||
}
|
category="mask/compositing",
|
||||||
}
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
io.Mask.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask/compositing"
|
@classmethod
|
||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
|
||||||
FUNCTION = "split_image_with_alpha"
|
|
||||||
|
|
||||||
def split_image_with_alpha(self, image: torch.Tensor):
|
|
||||||
out_images = [i[:,:,:3] for i in image]
|
out_images = [i[:,:,:3] for i in image]
|
||||||
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
|
||||||
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class JoinImageWithAlpha:
|
class JoinImageWithAlpha(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {
|
return io.Schema(
|
||||||
"required": {
|
node_id="JoinImageWithAlpha",
|
||||||
"image": ("IMAGE",),
|
display_name="Join Image with Alpha",
|
||||||
"alpha": ("MASK",),
|
category="mask/compositing",
|
||||||
}
|
inputs=[
|
||||||
}
|
io.Image.Input("image"),
|
||||||
|
io.Mask.Input("alpha"),
|
||||||
|
],
|
||||||
|
outputs=[io.Image.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "mask/compositing"
|
@classmethod
|
||||||
RETURN_TYPES = ("IMAGE",)
|
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||||
FUNCTION = "join_image_with_alpha"
|
|
||||||
|
|
||||||
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
|
|
||||||
batch_size = min(len(image), len(alpha))
|
batch_size = min(len(image), len(alpha))
|
||||||
out_images = []
|
out_images = []
|
||||||
|
|
||||||
@ -196,19 +206,18 @@ class JoinImageWithAlpha:
|
|||||||
for i in range(batch_size):
|
for i in range(batch_size):
|
||||||
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
|
||||||
|
|
||||||
result = (torch.stack(out_images),)
|
return io.NodeOutput(torch.stack(out_images))
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class CompositingExtension(ComfyExtension):
|
||||||
"PorterDuffImageComposite": PorterDuffImageComposite,
|
@override
|
||||||
"SplitImageWithAlpha": SplitImageWithAlpha,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"JoinImageWithAlpha": JoinImageWithAlpha,
|
return [
|
||||||
}
|
PorterDuffImageComposite,
|
||||||
|
SplitImageWithAlpha,
|
||||||
|
JoinImageWithAlpha,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
async def comfy_entrypoint() -> CompositingExtension:
|
||||||
"PorterDuffImageComposite": "Porter-Duff Image Composite",
|
return CompositingExtension()
|
||||||
"SplitImageWithAlpha": "Split Image with Alpha",
|
|
||||||
"JoinImageWithAlpha": "Join Image with Alpha",
|
|
||||||
}
|
|
||||||
|
|||||||
@ -1,20 +1,26 @@
|
|||||||
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
|
||||||
import nodes
|
import nodes
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class SetUnionControlNetType:
|
class SetUnionControlNetType(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"control_net": ("CONTROL_NET", ),
|
return io.Schema(
|
||||||
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
|
node_id="SetUnionControlNetType",
|
||||||
}}
|
category="conditioning/controlnet",
|
||||||
|
inputs=[
|
||||||
|
io.ControlNet.Input("control_net"),
|
||||||
|
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.ControlNet.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "conditioning/controlnet"
|
@classmethod
|
||||||
RETURN_TYPES = ("CONTROL_NET",)
|
def execute(cls, control_net, type) -> io.NodeOutput:
|
||||||
|
|
||||||
FUNCTION = "set_controlnet_type"
|
|
||||||
|
|
||||||
def set_controlnet_type(self, control_net, type):
|
|
||||||
control_net = control_net.copy()
|
control_net = control_net.copy()
|
||||||
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
type_number = UNION_CONTROLNET_TYPES.get(type, -1)
|
||||||
if type_number >= 0:
|
if type_number >= 0:
|
||||||
@ -22,27 +28,36 @@ class SetUnionControlNetType:
|
|||||||
else:
|
else:
|
||||||
control_net.set_extra_arg("control_type", [])
|
control_net.set_extra_arg("control_type", [])
|
||||||
|
|
||||||
return (control_net,)
|
return io.NodeOutput(control_net)
|
||||||
|
|
||||||
class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
set_controlnet_type = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="ControlNetInpaintingAliMamaApply",
|
||||||
"control_net": ("CONTROL_NET", ),
|
category="conditioning/controlnet",
|
||||||
"vae": ("VAE", ),
|
inputs=[
|
||||||
"image": ("IMAGE", ),
|
io.Conditioning.Input("positive"),
|
||||||
"mask": ("MASK", ),
|
io.Conditioning.Input("negative"),
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
io.ControlNet.Input("control_net"),
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
io.Vae.Input("vae"),
|
||||||
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
io.Image.Input("image"),
|
||||||
}}
|
io.Mask.Input("mask"),
|
||||||
|
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
|
||||||
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "apply_inpaint_controlnet"
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent) -> io.NodeOutput:
|
||||||
CATEGORY = "conditioning/controlnet"
|
|
||||||
|
|
||||||
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
|
|
||||||
extra_concat = []
|
extra_concat = []
|
||||||
if control_net.concat_mask:
|
if control_net.concat_mask:
|
||||||
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
|
||||||
@ -50,11 +65,20 @@ class ControlNetInpaintingAliMamaApply(nodes.ControlNetApplyAdvanced):
|
|||||||
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
|
||||||
extra_concat = [mask]
|
extra_concat = [mask]
|
||||||
|
|
||||||
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
result = nodes.ControlNetApplyAdvanced().apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
|
||||||
|
return io.NodeOutput(result[0], result[1])
|
||||||
|
|
||||||
|
apply_inpaint_controlnet = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNetExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SetUnionControlNetType,
|
||||||
|
ControlNetInpaintingAliMamaApply,
|
||||||
|
]
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"SetUnionControlNetType": SetUnionControlNetType,
|
async def comfy_entrypoint() -> ControlNetExtension:
|
||||||
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
|
return ControlNetExtension()
|
||||||
}
|
|
||||||
|
|||||||
@ -244,6 +244,8 @@ class EasyCacheHolder:
|
|||||||
self.total_steps_skipped += 1
|
self.total_steps_skipped += 1
|
||||||
batch_offset = x.shape[0] // len(uuids)
|
batch_offset = x.shape[0] // len(uuids)
|
||||||
for i, uuid in enumerate(uuids):
|
for i, uuid in enumerate(uuids):
|
||||||
|
# slice out only what is relevant to this cond
|
||||||
|
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
|
||||||
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
|
||||||
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
|
||||||
if not self.allow_mismatch:
|
if not self.allow_mismatch:
|
||||||
@ -261,9 +263,8 @@ class EasyCacheHolder:
|
|||||||
slicing.append(slice(None, dim_u))
|
slicing.append(slice(None, dim_u))
|
||||||
else:
|
else:
|
||||||
slicing.append(slice(None))
|
slicing.append(slice(None))
|
||||||
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
|
batch_slice = batch_slice + slicing
|
||||||
x = x[slicing]
|
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
|
||||||
x += self.uuid_cache_diffs[uuid].to(x.device)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy.k_diffusion.sampling import sigma_to_half_log_snr
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
@ -63,12 +65,105 @@ class EpsilonScaling(io.ComfyNode):
|
|||||||
return io.NodeOutput(model_clone)
|
return io.NodeOutput(model_clone)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_tsr_rescaling_factor(
|
||||||
|
snr: torch.Tensor, tsr_k: float, tsr_variance: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute the rescaling score ratio in Temporal Score Rescaling.
|
||||||
|
|
||||||
|
See equation (6) in https://arxiv.org/pdf/2510.01184v1.
|
||||||
|
"""
|
||||||
|
posinf_mask = torch.isposinf(snr)
|
||||||
|
rescaling_factor = (snr * tsr_variance + 1) / (snr * tsr_variance / tsr_k + 1)
|
||||||
|
return torch.where(posinf_mask, tsr_k, rescaling_factor) # when snr → inf, r = tsr_k
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalScoreRescaling(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="TemporalScoreRescaling",
|
||||||
|
display_name="TSR - Temporal Score Rescaling",
|
||||||
|
category="model_patches/unet",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Float.Input(
|
||||||
|
"tsr_k",
|
||||||
|
tooltip=(
|
||||||
|
"Controls the rescaling strength.\n"
|
||||||
|
"Lower k produces more detailed results; higher k produces smoother results in image generation. Setting k = 1 disables rescaling."
|
||||||
|
),
|
||||||
|
default=0.95,
|
||||||
|
min=0.01,
|
||||||
|
max=100.0,
|
||||||
|
step=0.001,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
),
|
||||||
|
io.Float.Input(
|
||||||
|
"tsr_sigma",
|
||||||
|
tooltip=(
|
||||||
|
"Controls how early rescaling takes effect.\n"
|
||||||
|
"Larger values take effect earlier."
|
||||||
|
),
|
||||||
|
default=1.0,
|
||||||
|
min=0.01,
|
||||||
|
max=100.0,
|
||||||
|
step=0.001,
|
||||||
|
display_mode=io.NumberDisplay.number,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(
|
||||||
|
display_name="patched_model",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
description=(
|
||||||
|
"[Post-CFG Function]\n"
|
||||||
|
"TSR - Temporal Score Rescaling (2510.01184)\n\n"
|
||||||
|
"Rescaling the model's score or noise to steer the sampling diversity.\n"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, tsr_k, tsr_sigma) -> io.NodeOutput:
|
||||||
|
tsr_variance = tsr_sigma**2
|
||||||
|
|
||||||
|
def temporal_score_rescaling(args):
|
||||||
|
denoised = args["denoised"]
|
||||||
|
x = args["input"]
|
||||||
|
sigma = args["sigma"]
|
||||||
|
curr_model = args["model"]
|
||||||
|
|
||||||
|
# No rescaling (r = 1) or no noise
|
||||||
|
if tsr_k == 1 or sigma == 0:
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
model_sampling = curr_model.current_patcher.get_model_object("model_sampling")
|
||||||
|
half_log_snr = sigma_to_half_log_snr(sigma, model_sampling)
|
||||||
|
snr = (2 * half_log_snr).exp()
|
||||||
|
|
||||||
|
# No rescaling needed (r = 1)
|
||||||
|
if snr == 0:
|
||||||
|
return denoised
|
||||||
|
|
||||||
|
rescaling_r = compute_tsr_rescaling_factor(snr, tsr_k, tsr_variance)
|
||||||
|
|
||||||
|
# Derived from scaled_denoised = (x - r * sigma * noise) / alpha
|
||||||
|
alpha = sigma * half_log_snr.exp()
|
||||||
|
return torch.lerp(x / alpha, denoised, rescaling_r)
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
m.set_model_sampler_post_cfg_function(temporal_score_rescaling)
|
||||||
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
class EpsilonScalingExtension(ComfyExtension):
|
class EpsilonScalingExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
EpsilonScaling,
|
EpsilonScaling,
|
||||||
|
TemporalScoreRescaling,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> EpsilonScalingExtension:
|
async def comfy_entrypoint() -> EpsilonScalingExtension:
|
||||||
return EpsilonScalingExtension()
|
return EpsilonScalingExtension()
|
||||||
|
|||||||
@ -1,60 +1,80 @@
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class CLIPTextEncodeFlux:
|
|
||||||
|
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeFlux",
|
||||||
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning/flux",
|
||||||
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("t5xxl", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_l, t5xxl, guidance) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_l, t5xxl, guidance):
|
|
||||||
tokens = clip.tokenize(clip_l)
|
tokens = clip.tokenize(clip_l)
|
||||||
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
|
||||||
|
|
||||||
return (clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens, add_dict={"guidance": guidance}))
|
||||||
|
|
||||||
class FluxGuidance:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class FluxGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxGuidance",
|
||||||
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
|
category="advanced/conditioning/flux",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Float.Input("guidance", default=3.5, min=0.0, max=100.0, step=0.1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning, guidance) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
|
|
||||||
def append(self, conditioning, guidance):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class FluxDisableGuidance:
|
class FluxDisableGuidance(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxDisableGuidance",
|
||||||
}}
|
category="advanced/conditioning/flux",
|
||||||
|
description="This node completely disables the guidance embed on Flux and Flux like models",
|
||||||
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
|
|
||||||
|
|
||||||
def append(self, conditioning):
|
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
PREFERED_KONTEXT_RESOLUTIONS = [
|
||||||
@ -78,52 +98,73 @@ PREFERED_KONTEXT_RESOLUTIONS = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class FluxKontextImageScale:
|
class FluxKontextImageScale(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"image": ("IMAGE", ),
|
return io.Schema(
|
||||||
},
|
node_id="FluxKontextImageScale",
|
||||||
}
|
category="advanced/conditioning/flux",
|
||||||
|
description="This node resizes the image to one that is more optimal for flux kontext.",
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("image"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
@classmethod
|
||||||
FUNCTION = "scale"
|
def execute(cls, image) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
DESCRIPTION = "This node resizes the image to one that is more optimal for flux kontext."
|
|
||||||
|
|
||||||
def scale(self, image):
|
|
||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
aspect_ratio = width / height
|
aspect_ratio = width / height
|
||||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
||||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||||
return (image, )
|
return io.NodeOutput(image)
|
||||||
|
|
||||||
|
scale = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class FluxKontextMultiReferenceLatentMethod:
|
class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"conditioning": ("CONDITIONING", ),
|
node_id="FluxKontextMultiReferenceLatentMethod",
|
||||||
"reference_latents_method": (("offset", "index", "uxo/uno"), ),
|
category="advanced/conditioning/flux",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Conditioning.Input("conditioning"),
|
||||||
|
io.Combo.Input(
|
||||||
|
"reference_latents_method",
|
||||||
|
options=["offset", "index", "uxo/uno"],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
@classmethod
|
||||||
FUNCTION = "append"
|
def execute(cls, conditioning, reference_latents_method) -> io.NodeOutput:
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning/flux"
|
|
||||||
|
|
||||||
def append(self, conditioning, reference_latents_method):
|
|
||||||
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
if "uxo" in reference_latents_method or "uso" in reference_latents_method:
|
||||||
reference_latents_method = "uxo"
|
reference_latents_method = "uxo"
|
||||||
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method})
|
||||||
return (c, )
|
return io.NodeOutput(c)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
append = execute # TODO: remove
|
||||||
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
|
|
||||||
"FluxGuidance": FluxGuidance,
|
|
||||||
"FluxDisableGuidance": FluxDisableGuidance,
|
class FluxExtension(ComfyExtension):
|
||||||
"FluxKontextImageScale": FluxKontextImageScale,
|
@override
|
||||||
"FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
}
|
return [
|
||||||
|
CLIPTextEncodeFlux,
|
||||||
|
FluxGuidance,
|
||||||
|
FluxDisableGuidance,
|
||||||
|
FluxKontextImageScale,
|
||||||
|
FluxKontextMultiReferenceLatentMethod,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> FluxExtension:
|
||||||
|
return FluxExtension()
|
||||||
|
|||||||
@ -2,42 +2,60 @@ import nodes
|
|||||||
import node_helpers
|
import node_helpers
|
||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
class CLIPTextEncodeHunyuanDiT:
|
class CLIPTextEncodeHunyuanDiT(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="CLIPTextEncodeHunyuanDiT",
|
||||||
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
category="advanced/conditioning",
|
||||||
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
}}
|
io.Clip.Input("clip"),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("bert", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.String.Input("mt5xl", multiline=True, dynamic_prompts=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, bert, mt5xl) -> io.NodeOutput:
|
||||||
def encode(self, clip, bert, mt5xl):
|
|
||||||
tokens = clip.tokenize(bert)
|
tokens = clip.tokenize(bert)
|
||||||
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
|
||||||
|
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
|
||||||
class EmptyHunyuanLatentVideo:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
return io.Schema(
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
node_id="EmptyHunyuanLatentVideo",
|
||||||
"length": ("INT", {"default": 25, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
category="latent/video",
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
inputs=[
|
||||||
RETURN_TYPES = ("LATENT",)
|
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
FUNCTION = "generate"
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
|
io.Int.Input("length", default=25, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/video"
|
@classmethod
|
||||||
|
def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput:
|
||||||
def generate(self, width, height, length, batch_size=1):
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples":latent}, )
|
return io.NodeOutput({"samples":latent})
|
||||||
|
|
||||||
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||||
@ -50,45 +68,61 @@ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
|||||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
class TextEncodeHunyuanVideo_ImageToVideo:
|
class TextEncodeHunyuanVideo_ImageToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"clip": ("CLIP", ),
|
node_id="TextEncodeHunyuanVideo_ImageToVideo",
|
||||||
"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
|
category="advanced/conditioning",
|
||||||
"prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
|
inputs=[
|
||||||
"image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
|
io.Clip.Input("clip"),
|
||||||
}}
|
io.ClipVisionOutput.Input("clip_vision_output"),
|
||||||
RETURN_TYPES = ("CONDITIONING",)
|
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
|
||||||
FUNCTION = "encode"
|
io.Int.Input(
|
||||||
|
"image_interleave",
|
||||||
|
default=2,
|
||||||
|
min=1,
|
||||||
|
max=512,
|
||||||
|
tooltip="How much the image influences things vs the text prompt. Higher number means more influence from the text prompt.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "advanced/conditioning"
|
@classmethod
|
||||||
|
def execute(cls, clip, clip_vision_output, prompt, image_interleave) -> io.NodeOutput:
|
||||||
def encode(self, clip, clip_vision_output, prompt, image_interleave):
|
|
||||||
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
|
||||||
return (clip.encode_from_tokens_scheduled(tokens), )
|
return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
|
||||||
|
|
||||||
class HunyuanImageToVideo:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanImageToVideo(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"vae": ("VAE", ),
|
node_id="HunyuanImageToVideo",
|
||||||
"width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
category="conditioning/video_models",
|
||||||
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
|
inputs=[
|
||||||
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
|
io.Conditioning.Input("positive"),
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
|
io.Vae.Input("vae"),
|
||||||
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
|
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
},
|
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||||
"optional": {"start_image": ("IMAGE", ),
|
io.Int.Input("length", default=53, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||||
}}
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
io.Combo.Input("guidance_type", options=["v1 (concat)", "v2 (replace)", "custom"]),
|
||||||
|
io.Image.Input("start_image", optional=True),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "LATENT")
|
@classmethod
|
||||||
RETURN_NAMES = ("positive", "latent")
|
def execute(cls, positive, vae, width, height, length, batch_size, guidance_type, start_image=None) -> io.NodeOutput:
|
||||||
FUNCTION = "encode"
|
|
||||||
|
|
||||||
CATEGORY = "conditioning/video_models"
|
|
||||||
|
|
||||||
def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
|
|
||||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
|
|
||||||
@ -111,51 +145,76 @@ class HunyuanImageToVideo:
|
|||||||
positive = node_helpers.conditioning_set_values(positive, cond)
|
positive = node_helpers.conditioning_set_values(positive, cond)
|
||||||
|
|
||||||
out_latent["samples"] = latent
|
out_latent["samples"] = latent
|
||||||
return (positive, out_latent)
|
return io.NodeOutput(positive, out_latent)
|
||||||
|
|
||||||
class EmptyHunyuanImageLatent:
|
encode = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyHunyuanImageLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "width": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
return io.Schema(
|
||||||
"height": ("INT", {"default": 2048, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
|
node_id="EmptyHunyuanImageLatent",
|
||||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
|
category="latent",
|
||||||
RETURN_TYPES = ("LATENT",)
|
inputs=[
|
||||||
FUNCTION = "generate"
|
io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||||
|
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent"
|
@classmethod
|
||||||
|
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
|
||||||
def generate(self, width, height, batch_size=1):
|
|
||||||
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
latent = torch.zeros([batch_size, 64, height // 32, width // 32], device=comfy.model_management.intermediate_device())
|
||||||
return ({"samples":latent}, )
|
return io.NodeOutput({"samples":latent})
|
||||||
|
|
||||||
class HunyuanRefinerLatent:
|
generate = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
|
class HunyuanRefinerLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"positive": ("CONDITIONING", ),
|
return io.Schema(
|
||||||
"negative": ("CONDITIONING", ),
|
node_id="HunyuanRefinerLatent",
|
||||||
"latent": ("LATENT", ),
|
inputs=[
|
||||||
"noise_augmentation": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01}),
|
io.Conditioning.Input("positive"),
|
||||||
}}
|
io.Conditioning.Input("negative"),
|
||||||
|
io.Latent.Input("latent"),
|
||||||
|
io.Float.Input("noise_augmentation", default=0.10, min=0.0, max=1.0, step=0.01),
|
||||||
|
|
||||||
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
|
],
|
||||||
RETURN_NAMES = ("positive", "negative", "latent")
|
outputs=[
|
||||||
|
io.Conditioning.Output(display_name="positive"),
|
||||||
|
io.Conditioning.Output(display_name="negative"),
|
||||||
|
io.Latent.Output(display_name="latent"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
FUNCTION = "execute"
|
@classmethod
|
||||||
|
def execute(cls, positive, negative, latent, noise_augmentation) -> io.NodeOutput:
|
||||||
def execute(self, positive, negative, latent, noise_augmentation):
|
|
||||||
latent = latent["samples"]
|
latent = latent["samples"]
|
||||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
||||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": latent, "noise_augmentation": noise_augmentation})
|
||||||
out_latent = {}
|
out_latent = {}
|
||||||
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
out_latent["samples"] = torch.zeros([latent.shape[0], 32, latent.shape[-3], latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||||
return (positive, negative, out_latent)
|
return io.NodeOutput(positive, negative, out_latent)
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
class HunyuanExtension(ComfyExtension):
|
||||||
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
|
@override
|
||||||
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
|
return [
|
||||||
"HunyuanImageToVideo": HunyuanImageToVideo,
|
CLIPTextEncodeHunyuanDiT,
|
||||||
"EmptyHunyuanImageLatent": EmptyHunyuanImageLatent,
|
TextEncodeHunyuanVideo_ImageToVideo,
|
||||||
"HunyuanRefinerLatent": HunyuanRefinerLatent,
|
EmptyHunyuanLatentVideo,
|
||||||
}
|
HunyuanImageToVideo,
|
||||||
|
EmptyHunyuanImageLatent,
|
||||||
|
HunyuanRefinerLatent,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> HunyuanExtension:
|
||||||
|
return HunyuanExtension()
|
||||||
|
|||||||
@ -2,6 +2,9 @@ import comfy.utils
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
def load_hypernetwork_patch(path, strength):
|
def load_hypernetwork_patch(path, strength):
|
||||||
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
sd = comfy.utils.load_torch_file(path, safe_load=True)
|
||||||
@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
|
|||||||
|
|
||||||
return hypernetwork_patch(out, strength)
|
return hypernetwork_patch(out, strength)
|
||||||
|
|
||||||
class HypernetworkLoader:
|
class HypernetworkLoader(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return IO.Schema(
|
||||||
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ),
|
node_id="HypernetworkLoader",
|
||||||
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
category="loaders",
|
||||||
}}
|
inputs=[
|
||||||
RETURN_TYPES = ("MODEL",)
|
IO.Model.Input("model"),
|
||||||
FUNCTION = "load_hypernetwork"
|
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||||
|
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "loaders"
|
@classmethod
|
||||||
|
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
|
||||||
def load_hypernetwork(self, model, hypernetwork_name, strength):
|
|
||||||
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
|
||||||
model_hypernetwork = model.clone()
|
model_hypernetwork = model.clone()
|
||||||
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
patch = load_hypernetwork_patch(hypernetwork_path, strength)
|
||||||
if patch is not None:
|
if patch is not None:
|
||||||
model_hypernetwork.set_model_attn1_patch(patch)
|
model_hypernetwork.set_model_attn1_patch(patch)
|
||||||
model_hypernetwork.set_model_attn2_patch(patch)
|
model_hypernetwork.set_model_attn2_patch(patch)
|
||||||
return (model_hypernetwork,)
|
return IO.NodeOutput(model_hypernetwork)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
load_hypernetwork = execute # TODO: remove
|
||||||
"HypernetworkLoader": HypernetworkLoader
|
|
||||||
}
|
|
||||||
|
class HyperNetworkExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
HypernetworkLoader,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> HyperNetworkExtension:
|
||||||
|
return HyperNetworkExtension()
|
||||||
|
|||||||
@ -2,6 +2,8 @@ import comfy.utils
|
|||||||
import comfy_extras.nodes_post_processing
|
import comfy_extras.nodes_post_processing
|
||||||
import torch
|
import torch
|
||||||
import nodes
|
import nodes
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
||||||
@ -13,17 +15,23 @@ def reshape_latent_to(target_shape, latent, repeat_batch=True):
|
|||||||
return latent
|
return latent
|
||||||
|
|
||||||
|
|
||||||
class LatentAdd:
|
class LatentAdd(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
return io.Schema(
|
||||||
|
node_id="LatentAdd",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -31,19 +39,25 @@ class LatentAdd:
|
|||||||
|
|
||||||
s2 = reshape_latent_to(s1.shape, s2)
|
s2 = reshape_latent_to(s1.shape, s2)
|
||||||
samples_out["samples"] = s1 + s2
|
samples_out["samples"] = s1 + s2
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentSubtract:
|
class LatentSubtract(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
return io.Schema(
|
||||||
|
node_id="LatentSubtract",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -51,41 +65,49 @@ class LatentSubtract:
|
|||||||
|
|
||||||
s2 = reshape_latent_to(s1.shape, s2)
|
s2 = reshape_latent_to(s1.shape, s2)
|
||||||
samples_out["samples"] = s1 - s2
|
samples_out["samples"] = s1 - s2
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentMultiply:
|
class LatentMultiply(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return io.Schema(
|
||||||
"multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}),
|
node_id="LatentMultiply",
|
||||||
}}
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, multiplier) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples, multiplier):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
samples_out["samples"] = s1 * multiplier
|
samples_out["samples"] = s1 * multiplier
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentInterpolate:
|
class LatentInterpolate(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",),
|
return io.Schema(
|
||||||
"samples2": ("LATENT",),
|
node_id="LatentInterpolate",
|
||||||
"ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
category="latent/advanced",
|
||||||
}}
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
io.Float.Input("ratio", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2, ratio) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2, ratio):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -104,19 +126,26 @@ class LatentInterpolate:
|
|||||||
st = torch.nan_to_num(t / mt)
|
st = torch.nan_to_num(t / mt)
|
||||||
|
|
||||||
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
samples_out["samples"] = st * (m1 * ratio + m2 * (1.0 - ratio))
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentConcat:
|
class LatentConcat(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",), "dim": (["x", "-x", "y", "-y", "t", "-t"], )}}
|
return io.Schema(
|
||||||
|
node_id="LatentConcat",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
io.Combo.Input("dim", options=["x", "-x", "y", "-y", "t", "-t"]),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples1, samples2, dim) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples1, samples2, dim):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
|
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
@ -136,22 +165,27 @@ class LatentConcat:
|
|||||||
dim = -3
|
dim = -3
|
||||||
|
|
||||||
samples_out["samples"] = torch.cat(c, dim=dim)
|
samples_out["samples"] = torch.cat(c, dim=dim)
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentCut:
|
class LatentCut(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {"samples": ("LATENT",),
|
return io.Schema(
|
||||||
"dim": (["x", "y", "t"], ),
|
node_id="LatentCut",
|
||||||
"index": ("INT", {"default": 0, "min": -nodes.MAX_RESOLUTION, "max": nodes.MAX_RESOLUTION, "step": 1}),
|
category="latent/advanced",
|
||||||
"amount": ("INT", {"default": 1, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 1})}}
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.Combo.Input("dim", options=["x", "y", "t"]),
|
||||||
|
io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
io.Int.Input("amount", default=1, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, dim, index, amount) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples, dim, index, amount):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
@ -171,19 +205,25 @@ class LatentCut:
|
|||||||
amount = min(-index, amount)
|
amount = min(-index, amount)
|
||||||
|
|
||||||
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
samples_out["samples"] = torch.narrow(s1, dim, index, amount)
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentBatch:
|
class LatentBatch(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples1": ("LATENT",), "samples2": ("LATENT",)}}
|
return io.Schema(
|
||||||
|
node_id="LatentBatch",
|
||||||
|
category="latent/batch",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples1"),
|
||||||
|
io.Latent.Input("samples2"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "batch"
|
def execute(cls, samples1, samples2) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/batch"
|
|
||||||
|
|
||||||
def batch(self, samples1, samples2):
|
|
||||||
samples_out = samples1.copy()
|
samples_out = samples1.copy()
|
||||||
s1 = samples1["samples"]
|
s1 = samples1["samples"]
|
||||||
s2 = samples2["samples"]
|
s2 = samples2["samples"]
|
||||||
@ -192,20 +232,25 @@ class LatentBatch:
|
|||||||
s = torch.cat((s1, s2), dim=0)
|
s = torch.cat((s1, s2), dim=0)
|
||||||
samples_out["samples"] = s
|
samples_out["samples"] = s
|
||||||
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
samples_out["batch_index"] = samples1.get("batch_index", [x for x in range(0, s1.shape[0])]) + samples2.get("batch_index", [x for x in range(0, s2.shape[0])])
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentBatchSeedBehavior:
|
class LatentBatchSeedBehavior(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return io.Schema(
|
||||||
"seed_behavior": (["random", "fixed"],{"default": "fixed"}),}}
|
node_id="LatentBatchSeedBehavior",
|
||||||
|
category="latent/advanced",
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, seed_behavior) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced"
|
|
||||||
|
|
||||||
def op(self, samples, seed_behavior):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
latent = samples["samples"]
|
latent = samples["samples"]
|
||||||
if seed_behavior == "random":
|
if seed_behavior == "random":
|
||||||
@ -215,41 +260,50 @@ class LatentBatchSeedBehavior:
|
|||||||
batch_number = samples_out.get("batch_index", [0])[0]
|
batch_number = samples_out.get("batch_index", [0])[0]
|
||||||
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
samples_out["batch_index"] = [batch_number] * latent.shape[0]
|
||||||
|
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentApplyOperation:
|
class LatentApplyOperation(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "samples": ("LATENT",),
|
return io.Schema(
|
||||||
"operation": ("LATENT_OPERATION",),
|
node_id="LatentApplyOperation",
|
||||||
}}
|
category="latent/advanced/operations",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Latent.Input("samples"),
|
||||||
|
io.LatentOperation.Input("operation"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, samples, operation) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def op(self, samples, operation):
|
|
||||||
samples_out = samples.copy()
|
samples_out = samples.copy()
|
||||||
|
|
||||||
s1 = samples["samples"]
|
s1 = samples["samples"]
|
||||||
samples_out["samples"] = operation(latent=s1)
|
samples_out["samples"] = operation(latent=s1)
|
||||||
return (samples_out,)
|
return io.NodeOutput(samples_out)
|
||||||
|
|
||||||
class LatentApplyOperationCFG:
|
class LatentApplyOperationCFG(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"operation": ("LATENT_OPERATION",),
|
node_id="LatentApplyOperationCFG",
|
||||||
}}
|
category="latent/advanced/operations",
|
||||||
RETURN_TYPES = ("MODEL",)
|
is_experimental=True,
|
||||||
FUNCTION = "patch"
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.LatentOperation.Input("operation"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
@classmethod
|
||||||
EXPERIMENTAL = True
|
def execute(cls, model, operation) -> io.NodeOutput:
|
||||||
|
|
||||||
def patch(self, model, operation):
|
|
||||||
m = model.clone()
|
m = model.clone()
|
||||||
|
|
||||||
def pre_cfg_function(args):
|
def pre_cfg_function(args):
|
||||||
@ -261,21 +315,25 @@ class LatentApplyOperationCFG:
|
|||||||
return conds_out
|
return conds_out
|
||||||
|
|
||||||
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
class LatentOperationTonemapReinhard:
|
class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01}),
|
return io.Schema(
|
||||||
}}
|
node_id="LatentOperationTonemapReinhard",
|
||||||
|
category="latent/advanced/operations",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.LatentOperation.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, multiplier) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def op(self, multiplier):
|
|
||||||
def tonemap_reinhard(latent, **kwargs):
|
def tonemap_reinhard(latent, **kwargs):
|
||||||
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
latent_vector_magnitude = (torch.linalg.vector_norm(latent, dim=(1)) + 0.0000000001)[:,None]
|
||||||
normalized_latent = latent / latent_vector_magnitude
|
normalized_latent = latent / latent_vector_magnitude
|
||||||
@ -291,39 +349,27 @@ class LatentOperationTonemapReinhard:
|
|||||||
new_magnitude *= top
|
new_magnitude *= top
|
||||||
|
|
||||||
return normalized_latent * new_magnitude
|
return normalized_latent * new_magnitude
|
||||||
return (tonemap_reinhard,)
|
return io.NodeOutput(tonemap_reinhard)
|
||||||
|
|
||||||
class LatentOperationSharpen:
|
class LatentOperationSharpen(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": {
|
return io.Schema(
|
||||||
"sharpen_radius": ("INT", {
|
node_id="LatentOperationSharpen",
|
||||||
"default": 9,
|
category="latent/advanced/operations",
|
||||||
"min": 1,
|
is_experimental=True,
|
||||||
"max": 31,
|
inputs=[
|
||||||
"step": 1
|
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1),
|
||||||
}),
|
io.Float.Input("sigma", default=1.0, min=0.1, max=10.0, step=0.1),
|
||||||
"sigma": ("FLOAT", {
|
io.Float.Input("alpha", default=0.1, min=0.0, max=5.0, step=0.01),
|
||||||
"default": 1.0,
|
],
|
||||||
"min": 0.1,
|
outputs=[
|
||||||
"max": 10.0,
|
io.LatentOperation.Output(),
|
||||||
"step": 0.1
|
],
|
||||||
}),
|
)
|
||||||
"alpha": ("FLOAT", {
|
|
||||||
"default": 0.1,
|
|
||||||
"min": 0.0,
|
|
||||||
"max": 5.0,
|
|
||||||
"step": 0.01
|
|
||||||
}),
|
|
||||||
}}
|
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT_OPERATION",)
|
@classmethod
|
||||||
FUNCTION = "op"
|
def execute(cls, sharpen_radius, sigma, alpha) -> io.NodeOutput:
|
||||||
|
|
||||||
CATEGORY = "latent/advanced/operations"
|
|
||||||
EXPERIMENTAL = True
|
|
||||||
|
|
||||||
def op(self, sharpen_radius, sigma, alpha):
|
|
||||||
def sharpen(latent, **kwargs):
|
def sharpen(latent, **kwargs):
|
||||||
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
luminance = (torch.linalg.vector_norm(latent, dim=(1)) + 1e-6)[:,None]
|
||||||
normalized_latent = latent / luminance
|
normalized_latent = latent / luminance
|
||||||
@ -340,19 +386,27 @@ class LatentOperationSharpen:
|
|||||||
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
sharpened = torch.nn.functional.conv2d(padded_image, kernel.repeat(channels, 1, 1).unsqueeze(1), padding=kernel_size // 2, groups=channels)[:,:,sharpen_radius:-sharpen_radius, sharpen_radius:-sharpen_radius]
|
||||||
|
|
||||||
return luminance * sharpened
|
return luminance * sharpened
|
||||||
return (sharpen,)
|
return io.NodeOutput(sharpen)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"LatentAdd": LatentAdd,
|
class LatentExtension(ComfyExtension):
|
||||||
"LatentSubtract": LatentSubtract,
|
@override
|
||||||
"LatentMultiply": LatentMultiply,
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
"LatentInterpolate": LatentInterpolate,
|
return [
|
||||||
"LatentConcat": LatentConcat,
|
LatentAdd,
|
||||||
"LatentCut": LatentCut,
|
LatentSubtract,
|
||||||
"LatentBatch": LatentBatch,
|
LatentMultiply,
|
||||||
"LatentBatchSeedBehavior": LatentBatchSeedBehavior,
|
LatentInterpolate,
|
||||||
"LatentApplyOperation": LatentApplyOperation,
|
LatentConcat,
|
||||||
"LatentApplyOperationCFG": LatentApplyOperationCFG,
|
LatentCut,
|
||||||
"LatentOperationTonemapReinhard": LatentOperationTonemapReinhard,
|
LatentBatch,
|
||||||
"LatentOperationSharpen": LatentOperationSharpen,
|
LatentBatchSeedBehavior,
|
||||||
}
|
LatentApplyOperation,
|
||||||
|
LatentApplyOperationCFG,
|
||||||
|
LatentOperationTonemapReinhard,
|
||||||
|
LatentOperationSharpen,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LatentExtension:
|
||||||
|
return LatentExtension()
|
||||||
|
|||||||
@ -5,6 +5,8 @@ import folder_paths
|
|||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
CLAMP_QUANTILE = 0.99
|
CLAMP_QUANTILE = 0.99
|
||||||
|
|
||||||
@ -71,32 +73,40 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
|
|||||||
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
|
||||||
return output_sd
|
return output_sd
|
||||||
|
|
||||||
class LoraSave:
|
class LoraSave(io.ComfyNode):
|
||||||
def __init__(self):
|
@classmethod
|
||||||
self.output_dir = folder_paths.get_output_directory()
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="LoraSave",
|
||||||
|
display_name="Extract and Save Lora",
|
||||||
|
category="_for_testing",
|
||||||
|
inputs=[
|
||||||
|
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||||
|
io.Int.Input("rank", default=8, min=1, max=4096, step=1),
|
||||||
|
io.Combo.Input("lora_type", options=tuple(LORA_TYPES.keys())),
|
||||||
|
io.Boolean.Input("bias_diff", default=True),
|
||||||
|
io.Model.Input(
|
||||||
|
"model_diff",
|
||||||
|
tooltip="The ModelSubtract output to be converted to a lora.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
io.Clip.Input(
|
||||||
|
"text_encoder_diff",
|
||||||
|
tooltip="The CLIPSubtract output to be converted to a lora.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
is_experimental=True,
|
||||||
|
is_output_node=True,
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def execute(cls, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None) -> io.NodeOutput:
|
||||||
return {"required": {"filename_prefix": ("STRING", {"default": "loras/ComfyUI_extracted_lora"}),
|
|
||||||
"rank": ("INT", {"default": 8, "min": 1, "max": 4096, "step": 1}),
|
|
||||||
"lora_type": (tuple(LORA_TYPES.keys()),),
|
|
||||||
"bias_diff": ("BOOLEAN", {"default": True}),
|
|
||||||
},
|
|
||||||
"optional": {"model_diff": ("MODEL", {"tooltip": "The ModelSubtract output to be converted to a lora."}),
|
|
||||||
"text_encoder_diff": ("CLIP", {"tooltip": "The CLIPSubtract output to be converted to a lora."})},
|
|
||||||
}
|
|
||||||
RETURN_TYPES = ()
|
|
||||||
FUNCTION = "save"
|
|
||||||
OUTPUT_NODE = True
|
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
|
||||||
|
|
||||||
def save(self, filename_prefix, rank, lora_type, bias_diff, model_diff=None, text_encoder_diff=None):
|
|
||||||
if model_diff is None and text_encoder_diff is None:
|
if model_diff is None and text_encoder_diff is None:
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
lora_type = LORA_TYPES.get(lora_type)
|
lora_type = LORA_TYPES.get(lora_type)
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory())
|
||||||
|
|
||||||
output_sd = {}
|
output_sd = {}
|
||||||
if model_diff is not None:
|
if model_diff is not None:
|
||||||
@ -108,12 +118,16 @@ class LoraSave:
|
|||||||
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
output_checkpoint = os.path.join(full_output_folder, output_checkpoint)
|
||||||
|
|
||||||
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
comfy.utils.save_torch_file(output_sd, output_checkpoint, metadata=None)
|
||||||
return {}
|
return io.NodeOutput()
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"LoraSave": LoraSave
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
class LoraSaveExtension(ComfyExtension):
|
||||||
"LoraSave": "Extract and Save Lora"
|
@override
|
||||||
}
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LoraSave,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> LoraSaveExtension:
|
||||||
|
return LoraSaveExtension()
|
||||||
|
|||||||
@ -1,24 +1,33 @@
|
|||||||
|
from typing_extensions import override
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
class PatchModelAddDownscale:
|
|
||||||
upscale_methods = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
class PatchModelAddDownscale(io.ComfyNode):
|
||||||
|
UPSCALE_METHODS = ["bicubic", "nearest-exact", "bilinear", "area", "bislerp"]
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def define_schema(cls):
|
||||||
return {"required": { "model": ("MODEL",),
|
return io.Schema(
|
||||||
"block_number": ("INT", {"default": 3, "min": 1, "max": 32, "step": 1}),
|
node_id="PatchModelAddDownscale",
|
||||||
"downscale_factor": ("FLOAT", {"default": 2.0, "min": 0.1, "max": 9.0, "step": 0.001}),
|
display_name="PatchModelAddDownscale (Kohya Deep Shrink)",
|
||||||
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
category="model_patches/unet",
|
||||||
"end_percent": ("FLOAT", {"default": 0.35, "min": 0.0, "max": 1.0, "step": 0.001}),
|
inputs=[
|
||||||
"downscale_after_skip": ("BOOLEAN", {"default": True}),
|
io.Model.Input("model"),
|
||||||
"downscale_method": (s.upscale_methods,),
|
io.Int.Input("block_number", default=3, min=1, max=32, step=1),
|
||||||
"upscale_method": (s.upscale_methods,),
|
io.Float.Input("downscale_factor", default=2.0, min=0.1, max=9.0, step=0.001),
|
||||||
}}
|
io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
|
||||||
RETURN_TYPES = ("MODEL",)
|
io.Float.Input("end_percent", default=0.35, min=0.0, max=1.0, step=0.001),
|
||||||
FUNCTION = "patch"
|
io.Boolean.Input("downscale_after_skip", default=True),
|
||||||
|
io.Combo.Input("downscale_method", options=cls.UPSCALE_METHODS),
|
||||||
|
io.Combo.Input("upscale_method", options=cls.UPSCALE_METHODS),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
CATEGORY = "model_patches/unet"
|
@classmethod
|
||||||
|
def execute(cls, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method) -> io.NodeOutput:
|
||||||
def patch(self, model, block_number, downscale_factor, start_percent, end_percent, downscale_after_skip, downscale_method, upscale_method):
|
|
||||||
model_sampling = model.get_model_object("model_sampling")
|
model_sampling = model.get_model_object("model_sampling")
|
||||||
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
sigma_start = model_sampling.percent_to_sigma(start_percent)
|
||||||
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
sigma_end = model_sampling.percent_to_sigma(end_percent)
|
||||||
@ -41,13 +50,21 @@ class PatchModelAddDownscale:
|
|||||||
else:
|
else:
|
||||||
m.set_model_input_block_patch(input_block_patch)
|
m.set_model_input_block_patch(input_block_patch)
|
||||||
m.set_model_output_block_patch(output_block_patch)
|
m.set_model_output_block_patch(output_block_patch)
|
||||||
return (m, )
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
|
||||||
"PatchModelAddDownscale": PatchModelAddDownscale,
|
|
||||||
}
|
|
||||||
|
|
||||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
# Sampling
|
# Sampling
|
||||||
"PatchModelAddDownscale": "PatchModelAddDownscale (Kohya Deep Shrink)",
|
"PatchModelAddDownscale": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class ModelDownscaleExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
PatchModelAddDownscale,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ModelDownscaleExtension:
|
||||||
|
return ModelDownscaleExtension()
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class PreviewAny():
|
|||||||
value = str(source)
|
value = str(source)
|
||||||
elif source is not None:
|
elif source is not None:
|
||||||
try:
|
try:
|
||||||
value = json.dumps(source)
|
value = json.dumps(source, indent=4)
|
||||||
except Exception:
|
except Exception:
|
||||||
try:
|
try:
|
||||||
value = str(source)
|
value = str(source)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user