Merge remote-tracking branch 'upstream/master' into multitalk

kijai 2025-11-03 17:41:31 +02:00
commit 25063f25cc
26 changed files with 946 additions and 1084 deletions

View File

@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@ -146,7 +147,7 @@ class PerformanceFeature(enum.Enum):
AutoTune = "autotune" AutoTune = "autotune"
PinnedMem = "pinned_memory" PinnedMem = "pinned_memory"
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature)))) parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")

View File

@ -522,7 +522,7 @@ class NextDiT(nn.Module):
max_cap_len = max(l_effective_cap_len) max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len) max_img_len = max(l_effective_img_len)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
for i in range(bsz): for i in range(bsz):
cap_len = l_effective_cap_len[i] cap_len = l_effective_cap_len[i]
@ -531,10 +531,22 @@ class NextDiT(nn.Module):
H_tokens, W_tokens = H // pH, W // pW H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len assert H_tokens * W_tokens == img_len
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)
h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
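As a rough standalone illustration of what the rope_options scaling in this hunk does to the 2-D token grid (the values below are hypothetical, not defaults used anywhere):

    import torch

    H_tokens, W_tokens = 4, 3
    h_scale, h_start = 0.5, 2.0   # stand-ins for rope_options["scale_y"] / ["shift_y"]
    w_scale, w_start = 1.0, 0.0   # stand-ins for rope_options["scale_x"] / ["shift_x"]

    row_ids = (torch.arange(H_tokens, dtype=torch.float32) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
    col_ids = (torch.arange(W_tokens, dtype=torch.float32) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
    # row_ids: [2.0, 2.0, 2.0, 2.5, 2.5, 2.5, 3.0, 3.0, 3.0, 3.5, 3.5, 3.5]
    # col_ids: [0.0, 1.0, 2.0] repeated per row -- positions are now float-valued,
    # which is why position_ids switched from int32 to float32 in this hunk.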

View File

@ -599,7 +599,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes) x = self.unpatchify(x, grid_sizes)
return x return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
patch_size = self.patch_size patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@ -612,10 +612,22 @@ class WanModel(torch.nn.Module):
if steps_w is None: if steps_w is None:
steps_w = w_len steps_w = w_len
h_start = 0
w_start = 0
rope_options = transformer_options.get("rope_options", None)
if rope_options is not None:
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
t_start += rope_options.get("shift_t", 0.0)
h_start += rope_options.get("shift_y", 0.0)
w_start += rope_options.get("shift_x", 0.0)
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2) freqs = self.rope_embedder(img_ids).movedim(1, 2)
@ -641,7 +653,7 @@ class WanModel(torch.nn.Module):
if self.ref_conv is not None and "reference_latent" in kwargs: if self.ref_conv is not None and "reference_latent" in kwargs:
t_len += 1 t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes): def unpatchify(self, x, grid_sizes):
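A quick standalone check of what the length rescaling in rope_encode does (numbers made up for illustration): the token count stays the same while the positional range is compressed or stretched, starting at the shifted origin.

    import torch

    h_len, steps_h = 16, 16        # token rows before scaling
    scale_y, shift_y = 0.5, 0.0    # stand-ins for rope_options values
    h_len = (h_len - 1.0) * scale_y + 1.0   # 8.5
    h_start = 0 + shift_y

    rows = torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h)
    # 16 evenly spaced positions from 0.0 to 7.5: same number of tokens,
    # half the positional span, so RoPE behaves as if the image were smaller.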

View File

@ -276,6 +276,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def get_ram_usage(self):
return self.model_size()
def loaded_size(self): def loaded_size(self):
return self.model.model_loaded_weight_memory return self.model.model_loaded_weight_memory
@ -295,6 +298,7 @@ class ModelPatcher:
n.backup = self.backup n.backup = self.backup
n.object_patches_backup = self.object_patches_backup n.object_patches_backup = self.object_patches_backup
n.parent = self n.parent = self
n.pinned = self.pinned
n.force_cast_weights = self.force_cast_weights n.force_cast_weights = self.force_cast_weights
@ -451,6 +455,19 @@ class ModelPatcher:
def set_model_post_input_patch(self, patch): def set_model_post_input_patch(self, patch):
self.set_model_patch(patch, "post_input") self.set_model_patch(patch, "post_input")
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
rope_options = self.model_options["transformer_options"].get("rope_options", {})
rope_options["scale_x"] = scale_x
rope_options["scale_y"] = scale_y
rope_options["scale_t"] = scale_t
rope_options["shift_x"] = shift_x
rope_options["shift_y"] = shift_y
rope_options["shift_t"] = shift_t
self.model_options["transformer_options"]["rope_options"] = rope_options
def add_object_patch(self, name, obj): def add_object_patch(self, name, obj):
self.object_patches[name] = obj self.object_patches[name] = obj
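How a caller is expected to use the new setter is not part of this diff; a hypothetical node body (names are illustrative only) would clone the patcher first so the rope options affect just that branch of the graph:

    def apply_rope_scaling(model, scale_x=1.0, scale_y=1.0, scale_t=1.0):
        m = model.clone()  # ModelPatcher.clone(); leaves the original untouched
        m.set_model_rope_options(scale_x=scale_x, shift_x=0.0,
                                 scale_y=scale_y, shift_y=0.0,
                                 scale_t=scale_t, shift_t=0.0)
        return m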

View File

@ -84,7 +84,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None: if device is None:
device = input.device device = input.device
if offloadable: if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
offload_stream = comfy.model_management.get_offload_stream(device) offload_stream = comfy.model_management.get_offload_stream(device)
else: else:
offload_stream = None offload_stream = None
@ -94,20 +95,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
else: else:
wf_context = contextlib.nullcontext() wf_context = contextlib.nullcontext()
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
has_function = len(s.bias_function) > 0
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function: weight_has_function = len(s.weight_function) > 0
bias_has_function = len(s.bias_function) > 0
weight = comfy.model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream)
bias = None
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
if bias_has_function:
with wf_context: with wf_context:
for f in s.bias_function: for f in s.bias_function:
bias = f(bias) bias = f(bias)
has_function = len(s.weight_function) > 0 weight = weight.to(dtype=dtype)
weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if weight_has_function:
if has_function:
with wf_context: with wf_context:
for f in s.weight_function: for f in s.weight_function:
weight = f(weight) weight = f(weight)
@ -401,15 +406,9 @@ def fp8_linear(self, input):
if dtype not in [torch.float8_e4m3fn]: if dtype not in [torch.float8_e4m3fn]:
return None return None
tensor_2d = False
if len(input.shape) == 2:
tensor_2d = True
input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype input_dtype = input.dtype
if len(input.shape) == 3: if input.ndim == 3 or input.ndim == 2:
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
scale_weight = self.scale_weight scale_weight = self.scale_weight
@ -422,24 +421,20 @@ def fp8_linear(self, input):
if scale_input is None: if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input) input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
quantized_input = QuantizedTensor(input.reshape(-1, input_shape[2]).to(dtype).contiguous(), TensorCoreFP8Layout, layout_params_weight) quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
else: else:
scale_input = scale_input.to(input.device) scale_input = scale_input.to(input.device)
quantized_input = QuantizedTensor.from_float(input.reshape(-1, input_shape[2]), TensorCoreFP8Layout, scale=scale_input, dtype=dtype) quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
# Wrap weight in QuantizedTensor - this enables unified dispatch # Wrap weight in QuantizedTensor - this enables unified dispatch
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype}
quantized_weight = QuantizedTensor(w, TensorCoreFP8Layout, layout_params_weight) quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight)
o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) o = torch.nn.functional.linear(quantized_input, quantized_weight, bias)
uncast_bias_weight(self, w, bias, offload_stream) uncast_bias_weight(self, w, bias, offload_stream)
return o
if tensor_2d:
return o.reshape(input_shape[0], -1)
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return None return None
@ -540,12 +535,12 @@ if CUBLAS_IS_AVAILABLE:
# ============================================================================== # ==============================================================================
# Mixed Precision Operations # Mixed Precision Operations
# ============================================================================== # ==============================================================================
from .quant_ops import QuantizedTensor, TensorCoreFP8Layout from .quant_ops import QuantizedTensor
QUANT_FORMAT_MIXINS = { QUANT_FORMAT_MIXINS = {
"float8_e4m3fn": { "float8_e4m3fn": {
"dtype": torch.float8_e4m3fn, "dtype": torch.float8_e4m3fn,
"layout_type": TensorCoreFP8Layout, "layout_type": "TensorCoreFP8Layout",
"parameters": { "parameters": {
"weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), "weight_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),
"input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False), "input_scale": torch.nn.Parameter(torch.zeros((), dtype=torch.float32), requires_grad=False),

View File

@ -123,7 +123,7 @@ class QuantizedTensor(torch.Tensor):
layout_type: Layout class (subclass of QuantizedLayout) layout_type: Layout class (subclass of QuantizedLayout)
layout_params: Dict with layout-specific parameters layout_params: Dict with layout-specific parameters
""" """
return torch.Tensor._make_subclass(cls, qdata, require_grad=False) return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False)
def __init__(self, qdata, layout_type, layout_params): def __init__(self, qdata, layout_type, layout_params):
self._qdata = qdata.contiguous() self._qdata = qdata.contiguous()
@ -183,11 +183,11 @@ class QuantizedTensor(torch.Tensor):
@classmethod @classmethod
def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor':
qdata, layout_params = layout_type.quantize(tensor, **quantize_kwargs) qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs)
return cls(qdata, layout_type, layout_params) return cls(qdata, layout_type, layout_params)
def dequantize(self) -> torch.Tensor: def dequantize(self) -> torch.Tensor:
return self._layout_type.dequantize(self._qdata, **self._layout_params) return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params)
@classmethod @classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
@ -379,7 +379,12 @@ class TensorCoreFP8Layout(QuantizedLayout):
return qtensor._qdata, qtensor._layout_params['scale'] return qtensor._qdata, qtensor._layout_params['scale']
@register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) LAYOUTS = {
"TensorCoreFP8Layout": TensorCoreFP8Layout,
}
@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout")
def fp8_linear(func, args, kwargs): def fp8_linear(func, args, kwargs):
input_tensor = args[0] input_tensor = args[0]
weight = args[1] weight = args[1]
@ -422,7 +427,7 @@ def fp8_linear(func, args, kwargs):
'scale': output_scale, 'scale': output_scale,
'orig_dtype': input_tensor._layout_params['orig_dtype'] 'orig_dtype': input_tensor._layout_params['orig_dtype']
} }
return QuantizedTensor(output, TensorCoreFP8Layout, output_params) return QuantizedTensor(output, "TensorCoreFP8Layout", output_params)
else: else:
return output return output
@ -436,3 +441,15 @@ def fp8_linear(func, args, kwargs):
input_tensor = input_tensor.dequantize() input_tensor = input_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight, bias) return torch.nn.functional.linear(input_tensor, weight, bias)
@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout")
@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout")
def fp8_func(func, args, kwargs):
input_tensor = args[0]
if isinstance(input_tensor, QuantizedTensor):
plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor)
ar = list(args)
ar[0] = plain_input
return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params)
return func(*args, **kwargs)
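The string-keyed LAYOUTS registry also makes the layout set extensible by name. A hypothetical extra layout, sketched against the interface implied above (a classmethod quantize returning (qdata, layout_params) and a matching dequantize), could be registered roughly like this; the layout itself is not part of this commit:

    import torch

    class SymmetricInt8Layout:  # hypothetical layout, for illustration only
        @classmethod
        def quantize(cls, tensor, scale=None):
            if scale is None:
                scale = (tensor.abs().amax() / 127.0).clamp(min=1e-8)
            qdata = torch.clamp((tensor / scale).round(), -128, 127).to(torch.int8)
            return qdata, {"scale": scale}

        @classmethod
        def dequantize(cls, qdata, scale):
            return qdata.to(torch.float32) * scale

    LAYOUTS["SymmetricInt8Layout"] = SymmetricInt8Layout  # string key, matching the new registry style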

View File

@ -143,6 +143,9 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model) return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -293,6 +296,7 @@ class VAE:
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False self.disable_offload = False
self.not_video = False self.not_video = False
self.size = None
self.downscale_index_formula = None self.downscale_index_formula = None
self.upscale_index_formula = None self.upscale_index_formula = None
@ -595,6 +599,16 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()
def model_size(self):
if self.size is not None:
return self.size
self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self): def throw_exception_if_invalid(self):
if self.first_stage_model is None: if self.first_stage_model is None:

View File

@ -1,73 +0,0 @@
from __future__ import annotations
import aiohttp
import mimetypes
from typing import Union
from server import PromptServer
import numpy as np
from PIL import Image
import torch
import base64
from io import BytesIO
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
response: The response to validate and cast.
timeout: Request timeout in seconds. Defaults to None (no timeout).
Returns:
A torch.Tensor representing the image (1, H, W, C).
Raises:
ValueError: If the response is not valid.
"""
# validate raw JSON response
data = response.data
if not data or len(data) == 0:
raise ValueError("No images returned from API endpoint")
# Initialize list to store image tensors
image_tensors: list[torch.Tensor] = []
# Process each image in the data array
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload neither URL nor base64 data present.")
pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))
return torch.stack(image_tensors, dim=0)
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"

View File

@ -46,7 +46,7 @@ class TextToVideoNode(IO.ComfyNode):
multiline=True, multiline=True,
default="", default="",
), ),
IO.Combo.Input("duration", options=[6, 8, 10], default=8), IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
options=[ options=[
@ -85,6 +85,10 @@ class TextToVideoNode(IO.ComfyNode):
generate_audio: bool = False, generate_audio: bool = False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000) validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
response = await sync_op_raw( response = await sync_op_raw(
cls, cls,
ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"),
@ -118,7 +122,7 @@ class ImageToVideoNode(IO.ComfyNode):
multiline=True, multiline=True,
default="", default="",
), ),
IO.Combo.Input("duration", options=[6, 8, 10], default=8), IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8),
IO.Combo.Input( IO.Combo.Input(
"resolution", "resolution",
options=[ options=[
@ -158,6 +162,10 @@ class ImageToVideoNode(IO.ComfyNode):
generate_audio: bool = False, generate_audio: bool = False,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=10000) validate_string(prompt, min_length=1, max_length=10000)
if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
raise ValueError(
"Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS."
)
if get_number_of_images(image) != 1: if get_number_of_images(image) != 1:
raise ValueError("Currently only one input image is supported.") raise ValueError("Currently only one input image is supported.")
response = await sync_op_raw( response = await sync_op_raw(

File diff suppressed because it is too large.

View File

@ -7,24 +7,23 @@ from __future__ import annotations
from io import BytesIO from io import BytesIO
import logging import logging
from typing import Optional, TypeVar from typing import Optional
import torch import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO from comfy_api.latest import ComfyExtension, IO
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apis import pika_defs from comfy_api_nodes.apis import pika_api as pika_defs
from comfy_api_nodes.apis.client import ( from comfy_api_nodes.util import (
validate_string,
download_url_to_video_output,
tensor_to_bytesio,
ApiEndpoint, ApiEndpoint,
EmptyRequest, sync_op,
HttpMethod, poll_op,
PollingOperation,
SynchronousOperation,
) )
from comfy_api_nodes.util import validate_string, download_url_to_video_output, tensor_to_bytesio
R = TypeVar("R")
PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions" PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps" PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
@ -40,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos"
async def execute_task( async def execute_task(
initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse], task_id: str,
auth_kwargs: Optional[dict[str, str]] = None, cls: type[IO.ComfyNode],
node_id: Optional[str] = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
task_id = (await initial_operation.execute()).video_id final_response: pika_defs.PikaVideoResponse = await poll_op(
final_response: pika_defs.PikaVideoResponse = await PollingOperation( cls,
poll_endpoint=ApiEndpoint( ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"),
path=f"{PATH_VIDEO_GET}/{task_id}", response_model=pika_defs.PikaVideoResponse,
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=pika_defs.PikaVideoResponse,
),
completed_statuses=["finished"],
failed_statuses=["failed", "cancelled"],
status_extractor=lambda response: (response.status.value if response.status else None), status_extractor=lambda response: (response.status.value if response.status else None),
progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None), progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None),
node_id=node_id,
estimated_duration=60, estimated_duration=60,
max_poll_attempts=240, max_poll_attempts=240,
).execute() )
if not final_response.url: if not final_response.url:
error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}" error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}"
logging.error(error_msg) logging.error(error_msg)
@ -124,23 +113,15 @@ class PikaImageToVideo(IO.ComfyNode):
resolution=resolution, resolution=resolution,
duration=duration, duration=duration,
) )
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_request_data,
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikaTextToVideoNode(IO.ComfyNode): class PikaTextToVideoNode(IO.ComfyNode):
@ -183,18 +164,11 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration: int, duration: int,
aspect_ratio: float, aspect_ratio: float,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_VIDEO,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
@ -202,10 +176,9 @@ class PikaTextToVideoNode(IO.ComfyNode):
duration=duration, duration=duration,
aspectRatio=aspect_ratio, aspectRatio=aspect_ratio,
), ),
auth_kwargs=auth,
content_type="application/x-www-form-urlencoded", content_type="application/x-www-form-urlencoded",
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikaScenes(IO.ComfyNode): class PikaScenes(IO.ComfyNode):
@ -309,24 +282,16 @@ class PikaScenes(IO.ComfyNode):
duration=duration, duration=duration,
aspectRatio=aspect_ratio, aspectRatio=aspect_ratio,
) )
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_PIKASCENES, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_request_data,
endpoint=ApiEndpoint(
path=PATH_PIKASCENES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikAdditionsNode(IO.ComfyNode): class PikAdditionsNode(IO.ComfyNode):
@ -383,24 +348,16 @@ class PikAdditionsNode(IO.ComfyNode):
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
) )
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_request_data,
endpoint=ApiEndpoint(
path=PATH_PIKADDITIONS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikaSwapsNode(IO.ComfyNode): class PikaSwapsNode(IO.ComfyNode):
@ -472,23 +429,15 @@ class PikaSwapsNode(IO.ComfyNode):
seed=seed, seed=seed,
modifyRegionRoi=region_to_modify if region_to_modify else None, modifyRegionRoi=region_to_modify if region_to_modify else None,
) )
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_PIKASWAPS, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_request_data,
endpoint=ApiEndpoint(
path=PATH_PIKASWAPS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_request_data,
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikaffectsNode(IO.ComfyNode): class PikaffectsNode(IO.ComfyNode):
@ -528,18 +477,11 @@ class PikaffectsNode(IO.ComfyNode):
negative_prompt: str, negative_prompt: str,
seed: int, seed: int,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
endpoint=ApiEndpoint(
path=PATH_PIKAFFECTS,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
pikaffect=pikaffect, pikaffect=pikaffect,
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
@ -547,9 +489,8 @@ class PikaffectsNode(IO.ComfyNode):
), ),
files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikaStartEndFrameNode(IO.ComfyNode): class PikaStartEndFrameNode(IO.ComfyNode):
@ -592,18 +533,11 @@ class PikaStartEndFrameNode(IO.ComfyNode):
("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")),
("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
] ]
auth = { initial_operation = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"),
} response_model=pika_defs.PikaGenerateResponse,
initial_operation = SynchronousOperation( data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
endpoint=ApiEndpoint(
path=PATH_PIKAFRAMES,
method=HttpMethod.POST,
request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
response_model=pika_defs.PikaGenerateResponse,
),
request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
promptText=prompt_text, promptText=prompt_text,
negativePrompt=negative_prompt, negativePrompt=negative_prompt,
seed=seed, seed=seed,
@ -612,9 +546,8 @@ class PikaStartEndFrameNode(IO.ComfyNode):
), ),
files=pika_files, files=pika_files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) return await execute_task(initial_operation.video_id, cls)
class PikaApiNodesExtension(ComfyExtension): class PikaApiNodesExtension(ComfyExtension):

View File

@ -225,21 +225,20 @@ async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] =
async def download_files(url_list, task_uuid): async def download_files(url_list, task_uuid):
save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") result_folder_name = f"Rodin3D_{task_uuid}"
save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name)
os.makedirs(save_path, exist_ok=True) os.makedirs(save_path, exist_ok=True)
model_file_path = None model_file_path = None
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
for i in url_list.list: for i in url_list.list:
url = i.url file_path = os.path.join(save_path, i.name)
file_name = i.name
file_path = os.path.join(save_path, file_name)
if file_path.endswith(".glb"): if file_path.endswith(".glb"):
model_file_path = file_path model_file_path = os.path.join(result_folder_name, i.name)
logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path)
max_retries = 5 max_retries = 5
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
async with session.get(url) as resp: async with session.get(i.url) as resp:
resp.raise_for_status() resp.raise_for_status()
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
async for chunk in resp.content.iter_chunked(32 * 1024): async for chunk in resp.content.iter_chunked(32 * 1024):

View File

@ -20,13 +20,6 @@ from comfy_api_nodes.apis.stability_api import (
StabilityAudioInpaintRequest, StabilityAudioInpaintRequest,
StabilityAudioResponse, StabilityAudioResponse,
) )
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
validate_audio_duration, validate_audio_duration,
validate_string, validate_string,
@ -34,6 +27,9 @@ from comfy_api_nodes.util import (
bytesio_to_image_tensor, bytesio_to_image_tensor,
tensor_to_bytesio, tensor_to_bytesio,
audio_bytes_to_audio_input, audio_bytes_to_audio_input,
sync_op,
poll_op,
ApiEndpoint,
) )
import torch import torch
@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
"image": image_binary "image": image_binary
} }
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
} response_model=StabilityStableUltraResponse,
data=StabilityStableUltraRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/ultra",
method=HttpMethod.POST,
request_model=StabilityStableUltraRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityStableUltraRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
"image": image_binary "image": image_binary
} }
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
} response_model=StabilityStableUltraResponse,
data=StabilityStable3_5Request(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/generate/sd3",
method=HttpMethod.POST,
request_model=StabilityStable3_5Request,
response_model=StabilityStableUltraResponse,
),
request=StabilityStable3_5Request(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
aspect_ratio=aspect_ratio, aspect_ratio=aspect_ratio,
@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
"image": image_binary "image": image_binary
} }
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
} response_model=StabilityStableUltraResponse,
data=StabilityUpscaleConservativeRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
method=HttpMethod.POST,
request_model=StabilityUpscaleConservativeRequest,
response_model=StabilityStableUltraResponse,
),
request=StabilityUpscaleConservativeRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
creativity=round(creativity,2), creativity=round(creativity,2),
@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
"image": image_binary "image": image_binary
} }
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
} response_model=StabilityAsyncResponse,
data=StabilityUpscaleCreativeRequest(
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/creative",
method=HttpMethod.POST,
request_model=StabilityUpscaleCreativeRequest,
response_model=StabilityAsyncResponse,
),
request=StabilityUpscaleCreativeRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
creativity=round(creativity,2), creativity=round(creativity,2),
@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
), ),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
response_api = await operation.execute()
operation = PollingOperation( response_poll = await poll_op(
poll_endpoint=ApiEndpoint( cls,
path=f"/proxy/stability/v2beta/results/{response_api.id}", ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
method=HttpMethod.GET, response_model=StabilityResultsGetResponse,
request_model=EmptyRequest,
response_model=StabilityResultsGetResponse,
),
poll_interval=3, poll_interval=3,
completed_statuses=[StabilityPollStatus.finished],
failed_statuses=[StabilityPollStatus.failed],
status_extractor=lambda x: get_async_dummy_status(x), status_extractor=lambda x: get_async_dummy_status(x),
auth_kwargs=auth,
node_id=cls.hidden.unique_id,
) )
response_poll: StabilityResultsGetResponse = await operation.execute()
if response_poll.finish_reason != "SUCCESS": if response_poll.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
"image": image_binary "image": image_binary
} }
auth = { response_api = await sync_op(
"auth_token": cls.hidden.auth_token_comfy_org, cls,
"comfy_api_key": cls.hidden.api_key_comfy_org, ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
} response_model=StabilityStableUltraResponse,
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/stability/v2beta/stable-image/upscale/fast",
method=HttpMethod.POST,
request_model=EmptyRequest,
response_model=StabilityStableUltraResponse,
),
request=EmptyRequest(),
files=files, files=files,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs=auth,
) )
response_api = await operation.execute()
if response_api.finish_reason != "SUCCESS": if response_api.finish_reason != "SUCCESS":
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode):
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput: async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
validate_string(prompt, max_length=10000) validate_string(prompt, max_length=10000)
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
method=HttpMethod.POST, response_model=StabilityAudioResponse,
request_model=StabilityTextToAudioRequest, data=payload,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data", content_type="multipart/form-data",
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
) )
response_api = await operation.execute()
if not response_api.audio: if not response_api.audio:
raise ValueError("No audio file was received in response.") raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode):
payload = StabilityAudioToAudioRequest( payload = StabilityAudioToAudioRequest(
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
) )
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
method=HttpMethod.POST, response_model=StabilityAudioResponse,
request_model=StabilityAudioToAudioRequest, data=payload,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data", content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)}, files={"audio": audio_input_to_mp3(audio)},
auth_kwargs= {
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
) )
response_api = await operation.execute()
if not response_api.audio: if not response_api.audio:
raise ValueError("No audio file was received in response.") raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode):
mask_start=mask_start, mask_start=mask_start,
mask_end=mask_end, mask_end=mask_end,
) )
operation = SynchronousOperation( response_api = await sync_op(
endpoint=ApiEndpoint( cls,
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
method=HttpMethod.POST, response_model=StabilityAudioResponse,
request_model=StabilityAudioInpaintRequest, data=payload,
response_model=StabilityAudioResponse,
),
request=payload,
content_type="multipart/form-data", content_type="multipart/form-data",
files={"audio": audio_input_to_mp3(audio)}, files={"audio": audio_input_to_mp3(audio)},
auth_kwargs={
"auth_token": cls.hidden.auth_token_comfy_org,
"comfy_api_key": cls.hidden.api_key_comfy_org,
},
) )
response_api = await operation.execute()
if not response_api.audio: if not response_api.audio:
raise ValueError("No audio file was received in response.") raise ValueError("No audio file was received in response.")
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

View File

@ -18,6 +18,8 @@ from .conversions import (
tensor_to_base64_string, tensor_to_base64_string,
tensor_to_bytesio, tensor_to_bytesio,
tensor_to_pil, tensor_to_pil,
text_filepath_to_base64_string,
text_filepath_to_data_uri,
trim_video, trim_video,
video_to_base64_string, video_to_base64_string,
) )
@ -75,6 +77,8 @@ __all__ = [
"tensor_to_base64_string", "tensor_to_base64_string",
"tensor_to_bytesio", "tensor_to_bytesio",
"tensor_to_pil", "tensor_to_pil",
"text_filepath_to_base64_string",
"text_filepath_to_data_uri",
"trim_video", "trim_video",
"video_to_base64_string", "video_to_base64_string",
# Validation utilities # Validation utilities

View File

@ -77,7 +77,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 429, 500, 502, 503, 504} _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"] COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"] FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"] QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
@ -589,7 +589,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
payload_headers = {"Accept": "*/*"} payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
payload_headers.update(get_auth_header(cfg.node_cls)) payload_headers.update(get_auth_header(cfg.node_cls))
if cfg.endpoint.headers: if cfg.endpoint.headers:

View File

@ -1,6 +1,7 @@
import base64 import base64
import logging import logging
import math import math
import mimetypes
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional
@ -12,7 +13,7 @@ from PIL import Image
from comfy.utils import common_upscale from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl from comfy_api.latest import Input, InputImpl
from comfy_api.util import VideoContainer, VideoCodec from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension from ._helpers import mimetype_to_extension
@ -451,3 +452,19 @@ def resize_mask_to_image(
if not allow_gradient: if not allow_gradient:
mask = (mask > 0.5).float() mask = (mask > 0.5).float()
return mask return mask
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"

View File

@ -1,4 +1,9 @@
import bisect
import gc
import itertools import itertools
import psutil
import time
import torch
from typing import Sequence, Mapping, Dict from typing import Sequence, Mapping, Dict
from comfy_execution.graph import DynamicPrompt from comfy_execution.graph import DynamicPrompt
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -48,7 +53,7 @@ class Unhashable:
def to_hashable(obj): def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples # So that we don't infinitely recurse since frozenset and tuples
# are Sequences. # are Sequences.
if isinstance(obj, (int, float, str, bool, type(None))): if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj return obj
elif isinstance(obj, Mapping): elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
@ -188,6 +193,9 @@ class BasicCache:
self._clean_cache() self._clean_cache()
self._clean_subcaches() self._clean_subcaches()
def poll(self, **kwargs):
pass
def _set_immediate(self, node_id, value): def _set_immediate(self, node_id, value):
assert self.initialized assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
@ -276,6 +284,9 @@ class NullCache:
def clean_unused(self): def clean_unused(self):
pass pass
def poll(self, **kwargs):
pass
def get(self, node_id): def get(self, node_id):
return None return None
@ -336,3 +347,75 @@ class LRUCache(BasicCache):
self._mark_used(child_id) self._mark_used(child_id)
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id)) self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self return self
#Iterating the cache for usage analysis might be expensive, so if we do trigger, make sure
#to take a chunk out to give breathing space on high-node / low-RAM-per-node flows.
RAM_CACHE_HYSTERESIS = 1.1
#This is roughly in GB, but not exact. It needs to be non-zero for the heuristic below,
#and as long as multi-GB models dwarf it, the OOM scoring remains a reasonable approximation.
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups.
RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache):
def __init__(self, key_class):
super().__init__(key_class, 0)
self.timestamps = {}
def clean_unused(self):
self._clean_subcaches()
def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value)
def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id)
def poll(self, ram_headroom):
def _ram_gb():
return psutil.virtual_memory().available / (1024**3)
if _ram_gb() > ram_headroom:
return
gc.collect()
if _ram_gb() > ram_headroom:
return
clean_list = []
for key, (outputs, _) in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
def scan_list_for_ram_usage(outputs):
nonlocal ram_usage
for output in outputs:
if isinstance(output, list):
scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to
#be high value intermediates
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
_, _, key = clean_list.pop()
del self.cache[key]
gc.collect()
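A minimal sketch (not part of the diff) of how the eviction score above combines workflow age and estimated RAM footprint; the concrete numbers are assumptions for illustration:

# Illustrative only: an entry last used 3 generations ago holding a 2 GiB CPU tensor.
age = 3                                  # self.generation - self.used_generation[key]
tensor_bytes = 2 * (1024 ** 3)
ram_usage = 0.1 + tensor_bytes * 0.5     # RAM_CACHE_DEFAULT_RAM_USAGE plus 50%-discounted tensor bytes
oom_score = (1.3 ** age) * ram_usage     # RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** age
# Higher scores sit at the end of clean_list and are popped (evicted) first, so large
# entries from stale workflows go before small or recently-used ones, until available
# RAM climbs back above ram_headroom * RAM_CACHE_HYSTERESIS.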

View File

@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id) self.execution_cache_listeners[from_node_id].add(to_node_id)
def get_output_cache(self, from_node_id, to_node_id): def get_cache(self, from_node_id, to_node_id):
if not to_node_id in self.execution_cache: if not to_node_id in self.execution_cache:
return None return None
return self.execution_cache[to_node_id].get(from_node_id) value = self.execution_cache[to_node_id].get(from_node_id)
if value is None:
return None
#Write back to the main cache on touch.
self.output_cache.set(from_node_id, value)
return value
def cache_update(self, node_id, value): def cache_update(self, node_id, value):
if node_id in self.execution_cache_listeners: if node_id in self.execution_cache_listeners:

View File

@ -2,6 +2,9 @@ import comfy.utils
import folder_paths import folder_paths
import torch import torch
import logging import logging
from comfy_api.latest import IO, ComfyExtension
from typing_extensions import override
def load_hypernetwork_patch(path, strength): def load_hypernetwork_patch(path, strength):
sd = comfy.utils.load_torch_file(path, safe_load=True) sd = comfy.utils.load_torch_file(path, safe_load=True)
@ -94,27 +97,42 @@ def load_hypernetwork_patch(path, strength):
return hypernetwork_patch(out, strength) return hypernetwork_patch(out, strength)
class HypernetworkLoader: class HypernetworkLoader(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "model": ("MODEL",), return IO.Schema(
"hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"), ), node_id="HypernetworkLoader",
"strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), category="loaders",
}} inputs=[
RETURN_TYPES = ("MODEL",) IO.Model.Input("model"),
FUNCTION = "load_hypernetwork" IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01),
],
outputs=[
IO.Model.Output(),
],
)
CATEGORY = "loaders" @classmethod
def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput:
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name) hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone() model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength) patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None: if patch is not None:
model_hypernetwork.set_model_attn1_patch(patch) model_hypernetwork.set_model_attn1_patch(patch)
model_hypernetwork.set_model_attn2_patch(patch) model_hypernetwork.set_model_attn2_patch(patch)
return (model_hypernetwork,) return IO.NodeOutput(model_hypernetwork)
NODE_CLASS_MAPPINGS = { load_hypernetwork = execute # TODO: remove
"HypernetworkLoader": HypernetworkLoader
}
class HyperNetworkExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
HypernetworkLoader,
]
async def comfy_entrypoint() -> HyperNetworkExtension:
return HyperNetworkExtension()
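A hedged sketch of why the `load_hypernetwork = execute` alias is kept: legacy call sites that instantiate the node and call the old method name still resolve to the new classmethod. The `model` object and file name below are assumptions for illustration:

# Hedged sketch only; assumes a loaded `model` and an existing hypernetwork file.
node = HypernetworkLoader()
out = node.load_hypernetwork(model, "example_hypernetwork.pt", 0.8)
# ...which is equivalent to the schema-based entry point:
out = HypernetworkLoader.execute(model, "example_hypernetwork.pt", 0.8)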

View File

@ -0,0 +1,47 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class ScaleROPE(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ScaleROPE",
category="advanced/model_patches",
description="Scale and shift the ROPE of the model.",
is_experimental=True,
inputs=[
io.Model.Input("model"),
io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
],
outputs=[
io.Model.Output(),
],
)
@classmethod
def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
m = model.clone()
m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
return io.NodeOutput(m)
class RopeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
ScaleROPE
]
async def comfy_entrypoint() -> RopeExtension:
return RopeExtension()
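The node itself only stores the six values on the cloned model via `set_model_rope_options`; below is a hedged sketch of the presumed per-axis effect, assuming a simple affine remap of RoPE positions (the actual math lives in the model's RoPE implementation, not in this node):

# Hedged sketch of the presumed effect; not the model's actual RoPE code.
def remap_position(pos: float, scale: float, shift: float) -> float:
    return pos * scale + shift

# e.g. with scale_t=0.5 and shift_t=0.0, frame index 10 would be treated as position 5.
print(remap_position(10, 0.5, 0.0))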

View File

@ -21,6 +21,7 @@ from comfy_execution.caching import (
NullCache, NullCache,
HierarchicalCache, HierarchicalCache,
LRUCache, LRUCache,
RAMPressureCache,
) )
from comfy_execution.graph import ( from comfy_execution.graph import (
DynamicPrompt, DynamicPrompt,
@ -88,49 +89,56 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
class CacheEntry(NamedTuple):
ui: dict
outputs: list
class CacheType(Enum): class CacheType(Enum):
CLASSIC = 0 CLASSIC = 0
LRU = 1 LRU = 1
NONE = 2 NONE = 2
RAM_PRESSURE = 3
class CacheSet: class CacheSet:
def __init__(self, cache_type=None, cache_size=None): def __init__(self, cache_type=None, cache_args={}):
if cache_type == CacheType.NONE: if cache_type == CacheType.NONE:
self.init_null_cache() self.init_null_cache()
logging.info("Disabling intermediate node cache.") logging.info("Disabling intermediate node cache.")
elif cache_type == CacheType.RAM_PRESSURE:
cache_ram = cache_args.get("ram", 16.0)
self.init_ram_cache(cache_ram)
logging.info("Using RAM pressure cache.")
elif cache_type == CacheType.LRU: elif cache_type == CacheType.LRU:
if cache_size is None: cache_size = cache_args.get("lru", 0)
cache_size = 0
self.init_lru_cache(cache_size) self.init_lru_cache(cache_size)
logging.info("Using LRU cache") logging.info("Using LRU cache")
else: else:
self.init_classic_cache() self.init_classic_cache()
self.all = [self.outputs, self.ui, self.objects] self.all = [self.outputs, self.objects]
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.outputs = HierarchicalCache(CacheKeySetInputSignature)
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size): def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
self.outputs = NullCache() self.outputs = NullCache()
#The UI cache is expected to be iterable at the end of each workflow
#so it must cache at least a full workflow. Use Hierarchical
self.ui = HierarchicalCache(CacheKeySetInputSignature)
self.objects = NullCache() self.objects = NullCache()
def recursive_debug_dump(self): def recursive_debug_dump(self):
result = { result = {
"outputs": self.outputs.recursive_debug_dump(), "outputs": self.outputs.recursive_debug_dump(),
"ui": self.ui.recursive_debug_dump(),
} }
return result return result
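A small self-contained sketch of the new cache entry shape introduced above, with illustrative values; it mirrors the `CacheEntry` NamedTuple rather than importing it:

from typing import NamedTuple

# Mirror of the CacheEntry added above, to illustrate the ui/outputs bundling (values are made up).
class CacheEntry(NamedTuple):
    ui: dict
    outputs: list

entry = CacheEntry(ui={"meta": {"node_id": "5"}, "output": {"images": []}},
                   outputs=[["first output slot"]])
assert entry.ui["output"] == {"images": []}
assert entry.outputs[0][0] == "first output slot"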
@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if execution_list is None: if execution_list is None:
mark_missing() mark_missing()
continue # This might be a lazily-evaluated input continue # This might be a lazily-evaluated input
cached_output = execution_list.get_output_cache(input_unique_id, unique_id) cached = execution_list.get_cache(input_unique_id, unique_id)
if cached_output is None: if cached is None or cached.outputs is None:
mark_missing() mark_missing()
continue continue
if output_index >= len(cached_output): if output_index >= len(cached.outputs):
mark_missing() mark_missing()
continue continue
obj = cached_output[output_index] obj = cached.outputs[output_index]
input_data_all[x] = obj input_data_all[x] = obj
elif input_category is not None: elif input_category is not None:
input_data_all[x] = [input_data] input_data_all[x] = [input_data]
@ -393,7 +401,7 @@ def format_value(x):
else: else:
return str(x) return str(x)
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes): async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
display_node_id = dynprompt.get_display_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id)
@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs'] inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type'] class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if caches.outputs.get(unique_id) is not None: cached = caches.outputs.get(unique_id)
if cached is not None:
if server.client_id is not None: if server.client_id is not None:
cached_output = caches.ui.get(unique_id) or {} cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id) get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, caches.outputs.get(unique_id)) execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
input_data_all = None input_data_all = None
@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
for r in result: for r in result:
if is_link(r): if is_link(r):
source_node, source_output = r[0], r[1] source_node, source_output = r[0], r[1]
node_output = execution_list.get_output_cache(source_node, unique_id)[source_output] node_cached = execution_list.get_cache(source_node, unique_id)
for o in node_output: for o in node_cached.outputs[source_output]:
resolved_output.append(o) resolved_output.append(o)
else: else:
@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
asyncio.create_task(await_completion()) asyncio.create_task(await_completion())
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
if len(output_ui) > 0: if len(output_ui) > 0:
caches.ui.set(unique_id, { ui_outputs[unique_id] = {
"meta": { "meta": {
"node_id": unique_id, "node_id": unique_id,
"display_node": display_node_id, "display_node": display_node_id,
@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
"real_node_id": real_node_id, "real_node_id": real_node_id,
}, },
"output": output_ui "output": output_ui
}) }
if server.client_id is not None: if server.client_id is not None:
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id) server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
if has_subgraph: if has_subgraph:
@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
pending_subgraph_results[unique_id] = cached_outputs pending_subgraph_results[unique_id] = cached_outputs
return (ExecutionResult.PENDING, None, None) return (ExecutionResult.PENDING, None, None)
caches.outputs.set(unique_id, output_data) cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, output_data) execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex: except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
class PromptExecutor: class PromptExecutor:
def __init__(self, server, cache_type=False, cache_size=None): def __init__(self, server, cache_type=False, cache_args=None):
self.cache_size = cache_size self.cache_args = cache_args
self.cache_type = cache_type self.cache_type = cache_type
self.server = server self.server = server
self.reset() self.reset()
def reset(self): def reset(self):
self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
self.status_messages = [] self.status_messages = []
self.success = True self.success = True
@ -682,6 +694,7 @@ class PromptExecutor:
broadcast=False) broadcast=False)
pending_subgraph_results = {} pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set() executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids() current_outputs = self.caches.outputs.all_node_ids()
@ -695,7 +708,7 @@ class PromptExecutor:
break break
assert node_id is not None, "Node ID should not be None at this point" assert node_id is not None, "Node ID should not be None at this point"
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
self.success = result != ExecutionResult.FAILURE self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE: if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@ -704,18 +717,16 @@ class PromptExecutor:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution() execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {} ui_outputs = {}
meta_outputs = {} meta_outputs = {}
all_node_ids = self.caches.ui.all_node_ids() for node_id, ui_info in ui_node_outputs.items():
for node_id in all_node_ids: ui_outputs[node_id] = ui_info["output"]
ui_info = self.caches.ui.get(node_id) meta_outputs[node_id] = ui_info["meta"]
if ui_info is not None:
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = { self.history_result = {
"outputs": ui_outputs, "outputs": ui_outputs,
"meta": meta_outputs, "meta": meta_outputs,

View File

@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0: if args.cache_lru > 0:
cache_type = execution.CacheType.LRU cache_type = execution.CacheType.LRU
elif args.cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none: elif args.cache_none:
cache_type = execution.CacheType.NONE cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0
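For orientation, a hedged summary of how the cache flags now map to cache types and `cache_args`; command lines and values are illustrative, and `--cache-lru` takes precedence over `--cache-ram` when both are given:

# Illustrative mapping (example values, not defaults):
#   python main.py --cache-lru 20   -> CacheType.LRU,          cache_args={"lru": 20, "ram": 0}
#   python main.py --cache-ram 8    -> CacheType.RAM_PRESSURE, cache_args={"lru": 0,  "ram": 8.0}
#   python main.py --cache-none     -> CacheType.NONE
#   (no cache flag)                 -> CacheType.CLASSIC
# During the execution loop the outputs cache is then polled with ram_headroom taken
# from cache_args["ram"]; for the non-RAM caches poll() is a no-op.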

View File

@ -2329,6 +2329,7 @@ async def init_builtin_extra_nodes():
"nodes_model_patch.py", "nodes_model_patch.py",
"nodes_easycache.py", "nodes_easycache.py",
"nodes_audio_encoder.py", "nodes_audio_encoder.py",
"nodes_rope.py",
] ]
import_failed = [] import_failed = []

View File

@ -14,7 +14,7 @@ if not has_gpu():
args.cpu = True args.cpu = True
from comfy import ops from comfy import ops
from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout from comfy.quant_ops import QuantizedTensor
class SimpleModel(torch.nn.Module): class SimpleModel(torch.nn.Module):
@ -104,14 +104,14 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify weights are wrapped in QuantizedTensor # Verify weights are wrapped in QuantizedTensor
self.assertIsInstance(model.layer1.weight, QuantizedTensor) self.assertIsInstance(model.layer1.weight, QuantizedTensor)
self.assertEqual(model.layer1.weight._layout_type, TensorCoreFP8Layout) self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout")
# Layer 2 should NOT be quantized # Layer 2 should NOT be quantized
self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)
# Layer 3 should be quantized # Layer 3 should be quantized
self.assertIsInstance(model.layer3.weight, QuantizedTensor) self.assertIsInstance(model.layer3.weight, QuantizedTensor)
self.assertEqual(model.layer3.weight._layout_type, TensorCoreFP8Layout) self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout")
# Verify scales were loaded # Verify scales were loaded
self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0)
@ -155,7 +155,7 @@ class TestMixedPrecisionOps(unittest.TestCase):
# Verify layer1.weight is a QuantizedTensor with scale preserved # Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_type, TensorCoreFP8Layout) self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout")
# Verify non-quantized layers are standard tensors # Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)

View File

@ -25,14 +25,14 @@ class TestQuantizedTensor(unittest.TestCase):
scale = torch.tensor(2.0) scale = torch.tensor(2.0)
layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
self.assertIsInstance(qt, QuantizedTensor) self.assertIsInstance(qt, QuantizedTensor)
self.assertEqual(qt.shape, (256, 128)) self.assertEqual(qt.shape, (256, 128))
self.assertEqual(qt.dtype, torch.float8_e4m3fn) self.assertEqual(qt.dtype, torch.float8_e4m3fn)
self.assertEqual(qt._layout_params['scale'], scale) self.assertEqual(qt._layout_params['scale'], scale)
self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
self.assertEqual(qt._layout_type, TensorCoreFP8Layout) self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")
def test_dequantize(self): def test_dequantize(self):
"""Test explicit dequantization""" """Test explicit dequantization"""
@ -41,7 +41,7 @@ class TestQuantizedTensor(unittest.TestCase):
scale = torch.tensor(3.0) scale = torch.tensor(3.0)
layout_params = {'scale': scale, 'orig_dtype': torch.float32} layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
dequantized = qt.dequantize() dequantized = qt.dequantize()
self.assertEqual(dequantized.dtype, torch.float32) self.assertEqual(dequantized.dtype, torch.float32)
@ -54,7 +54,7 @@ class TestQuantizedTensor(unittest.TestCase):
qt = QuantizedTensor.from_float( qt = QuantizedTensor.from_float(
float_tensor, float_tensor,
TensorCoreFP8Layout, "TensorCoreFP8Layout",
scale=scale, scale=scale,
dtype=torch.float8_e4m3fn dtype=torch.float8_e4m3fn
) )
@ -77,28 +77,28 @@ class TestGenericUtilities(unittest.TestCase):
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5) scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32} layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Detach should return a new QuantizedTensor # Detach should return a new QuantizedTensor
qt_detached = qt.detach() qt_detached = qt.detach()
self.assertIsInstance(qt_detached, QuantizedTensor) self.assertIsInstance(qt_detached, QuantizedTensor)
self.assertEqual(qt_detached.shape, qt.shape) self.assertEqual(qt_detached.shape, qt.shape)
self.assertEqual(qt_detached._layout_type, TensorCoreFP8Layout) self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")
def test_clone(self): def test_clone(self):
"""Test clone operation on quantized tensor""" """Test clone operation on quantized tensor"""
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5) scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32} layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Clone should return a new QuantizedTensor # Clone should return a new QuantizedTensor
qt_cloned = qt.clone() qt_cloned = qt.clone()
self.assertIsInstance(qt_cloned, QuantizedTensor) self.assertIsInstance(qt_cloned, QuantizedTensor)
self.assertEqual(qt_cloned.shape, qt.shape) self.assertEqual(qt_cloned.shape, qt.shape)
self.assertEqual(qt_cloned._layout_type, TensorCoreFP8Layout) self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")
# Verify it's a deep copy # Verify it's a deep copy
self.assertIsNot(qt_cloned._qdata, qt._qdata) self.assertIsNot(qt_cloned._qdata, qt._qdata)
@ -109,7 +109,7 @@ class TestGenericUtilities(unittest.TestCase):
fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
scale = torch.tensor(1.5) scale = torch.tensor(1.5)
layout_params = {'scale': scale, 'orig_dtype': torch.float32} layout_params = {'scale': scale, 'orig_dtype': torch.float32}
qt = QuantizedTensor(fp8_data, TensorCoreFP8Layout, layout_params) qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
# Moving to same device should work (CPU to CPU) # Moving to same device should work (CPU to CPU)
qt_cpu = qt.to('cpu') qt_cpu = qt.to('cpu')
@ -169,7 +169,7 @@ class TestFallbackMechanism(unittest.TestCase):
scale = torch.tensor(1.0) scale = torch.tensor(1.0)
a_q = QuantizedTensor.from_float( a_q = QuantizedTensor.from_float(
a_fp32, a_fp32,
TensorCoreFP8Layout, "TensorCoreFP8Layout",
scale=scale, scale=scale,
dtype=torch.float8_e4m3fn dtype=torch.float8_e4m3fn
) )
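The tests above track the change from passing the layout class to passing its string name; below is a hedged round-trip sketch using only calls that appear in these tests, assuming a ComfyUI environment where `comfy.quant_ops` is importable and FP8 dtypes are available:

import torch
from comfy.quant_ops import QuantizedTensor

# Hedged sketch; mirrors the from_float/dequantize calls exercised in the tests above.
w = torch.randn(256, 128)
qt = QuantizedTensor.from_float(w, "TensorCoreFP8Layout",
                                scale=torch.tensor(2.0), dtype=torch.float8_e4m3fn)
assert qt._layout_type == "TensorCoreFP8Layout"   # layout identified by name, not by class
w_back = qt.dequantize()                          # back to a regular tensor for comparison/debugging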