mirror of
https://git.datalinker.icu/kijai/ComfyUI-KJNodes.git
synced 2025-12-09 12:54:40 +08:00
Compare commits
14 Commits
567e05f7d6
...
de6031f606
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de6031f606 | ||
|
|
50e7dd34d3 | ||
|
|
37206374ef | ||
|
|
06a60ac3fe | ||
|
|
a7ce03e735 | ||
|
|
f37df472df | ||
|
|
390d05fe7e | ||
|
|
f0ed965cd9 | ||
|
|
acdd16a973 | ||
|
|
4dfb85dcc5 | ||
|
|
8660778ea1 | ||
|
|
3b9c1b49ab | ||
|
|
246920d8b9 | ||
|
|
363c90f4e3 |
@ -209,6 +209,12 @@ NODE_CONFIG = {
|
||||
"ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
|
||||
"WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
|
||||
"GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
|
||||
"LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"},
|
||||
"NABLA_AttentionKJ": {"class": NABLA_AttentionKJ, "name": "NABLA Attention KJ"},
|
||||
"TorchCompileModelAdvanced": {"class": TorchCompileModelAdvanced, "name": "TorchCompileModelAdvanced"},
|
||||
"StartRecordCUDAMemoryHistory": {"class": StartRecordCUDAMemoryHistory, "name": "Start Recording CUDAMemory History"},
|
||||
"EndRecordCUDAMemoryHistory": {"class": EndRecordCUDAMemoryHistory, "name": "End Recording CUDAMemory History"},
|
||||
"VisualizeCUDAMemoryHistory": {"class": VisualizeCUDAMemoryHistory, "name": "Visualize CUDAMemory History"},
|
||||
|
||||
#instance diffusion
|
||||
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},
|
||||
|
||||
@ -176,7 +176,7 @@ Saves an image and mask as .PNG with the mask as the alpha channel.
|
||||
def file_counter():
|
||||
max_counter = 0
|
||||
# Loop through the existing files
|
||||
for existing_file in os.listdir(full_output_folder):
|
||||
for existing_file in sorted(os.listdir(full_output_folder)):
|
||||
# Check if the file matches the expected format
|
||||
match = re.fullmatch(fr"{filename}_(\d+)_?\.[a-zA-Z0-9]+", existing_file)
|
||||
if match:
|
||||
@ -2981,7 +2981,7 @@ class LoadImagesFromFolderKJ:
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
for file in os.listdir(folder):
|
||||
for file in sorted(os.listdir(folder)):
|
||||
if any(file.lower().endswith(ext) for ext in valid_extensions):
|
||||
path = os.path.join(folder, file)
|
||||
try:
|
||||
@ -3043,7 +3043,7 @@ class LoadImagesFromFolderKJ:
|
||||
if any(file.lower().endswith(ext) for ext in valid_extensions):
|
||||
image_paths.append(os.path.join(root, file))
|
||||
else:
|
||||
for file in os.listdir(folder):
|
||||
for file in sorted(os.listdir(folder)):
|
||||
if any(file.lower().endswith(ext) for ext in valid_extensions):
|
||||
image_paths.append(os.path.join(folder, file))
|
||||
|
||||
@ -3964,7 +3964,7 @@ class LoadVideosFromFolder:
|
||||
raise ImportError("This node requires ComfyUI-VideoHelperSuite to be installed.")
|
||||
videos_list = []
|
||||
filenames = []
|
||||
for f in os.listdir(kwargs['video']):
|
||||
for f in sorted(os.listdir(kwargs['video'])):
|
||||
if os.path.isfile(os.path.join(kwargs['video'], f)):
|
||||
file_parts = f.split('.')
|
||||
if len(file_parts) > 1 and (file_parts[-1].lower() in ['webm', 'mp4', 'mkv', 'gif', 'mov']):
|
||||
|
||||
@ -3,15 +3,18 @@ from comfy.ldm.modules import attention as comfy_attention
|
||||
import logging
|
||||
import torch
|
||||
import importlib
|
||||
import math
|
||||
import datetime
|
||||
|
||||
import folder_paths
|
||||
import comfy.model_management as mm
|
||||
from comfy.cli_args import args
|
||||
from comfy.ldm.modules.attention import wrap_attn
|
||||
from comfy.ldm.modules.attention import wrap_attn, optimized_attention
|
||||
import comfy.model_patcher
|
||||
import comfy.utils
|
||||
import comfy.sd
|
||||
|
||||
|
||||
try:
|
||||
from comfy_api.latest import io
|
||||
v3_available = True
|
||||
@ -71,6 +74,9 @@ def get_sage_func(sage_attention, allow_compile=False):
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
in_dtype = v.dtype
|
||||
if q.dtype == torch.float32 or k.dtype == torch.float32 or v.dtype == torch.float32:
|
||||
q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
@ -89,7 +95,7 @@ def get_sage_func(sage_attention, allow_compile=False):
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout).to(in_dtype)
|
||||
if tensor_layout == "HND":
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
@ -675,6 +681,7 @@ class TorchCompileModelFluxAdvancedV2:
|
||||
try:
|
||||
if double_blocks:
|
||||
for i, block in enumerate(diffusion_model.double_blocks):
|
||||
print("Adding double block to compile list", i)
|
||||
compile_key_list.append(f"diffusion_model.double_blocks.{i}")
|
||||
if single_blocks:
|
||||
for i, block in enumerate(diffusion_model.single_blocks):
|
||||
@ -718,7 +725,7 @@ class TorchCompileModelHyVideo:
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
|
||||
DEPRECATED = True
|
||||
CATEGORY = "KJNodes/torchcompile"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
@ -850,7 +857,60 @@ class TorchCompileModelWanVideoV2:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
|
||||
return (m, )
|
||||
|
||||
|
||||
|
||||
class TorchCompileModelAdvanced:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"model": ("MODEL",),
|
||||
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
|
||||
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
|
||||
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
|
||||
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
|
||||
"compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
|
||||
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
|
||||
"debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}),
|
||||
},
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
CATEGORY = "KJNodes/torchcompile"
|
||||
DESCRIPTION = "Advanced torch.compile patching for diffusion models."
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, debug_compile_keys):
|
||||
from comfy_api.torch_helpers import set_torch_compile_wrapper
|
||||
m = model.clone()
|
||||
diffusion_model = m.get_model_object("diffusion_model")
|
||||
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
|
||||
|
||||
try:
|
||||
if compile_transformer_blocks_only:
|
||||
layer_types = ["double_blocks", "single_blocks", "layers", "transformer_blocks", "blocks", "visual_transformer_blocks", "text_transformer_blocks"]
|
||||
compile_key_list = []
|
||||
for layer_name in layer_types:
|
||||
if hasattr(diffusion_model, layer_name):
|
||||
blocks = getattr(diffusion_model, layer_name)
|
||||
for i in range(len(blocks)):
|
||||
compile_key_list.append(f"diffusion_model.{layer_name}.{i}")
|
||||
if not compile_key_list:
|
||||
logging.warning("No known transformer blocks found to compile, compiling entire diffusion model instead")
|
||||
elif debug_compile_keys:
|
||||
logging.info("TorchCompileModelAdvanced: Compile key list:")
|
||||
for key in compile_key_list:
|
||||
logging.info(f" - {key}")
|
||||
if not compile_key_list:
|
||||
compile_key_list =["diffusion_model"]
|
||||
|
||||
set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
|
||||
except:
|
||||
raise RuntimeError("Failed to compile model")
|
||||
|
||||
return (m, )
|
||||
|
||||
|
||||
class TorchCompileModelQwenImage:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
@ -2005,3 +2065,241 @@ else:
|
||||
FUNCTION = ""
|
||||
CATEGORY = ""
|
||||
DESCRIPTION = "This node requires newer ComfyUI"
|
||||
|
||||
|
||||
try:
|
||||
from torch.nn.attention.flex_attention import flex_attention, BlockMask
|
||||
except:
|
||||
flex_attention = None
|
||||
BlockMask = None
|
||||
|
||||
class NABLA_AttentionKJ():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL",),
|
||||
"latent": ("LATENT", {"tooltip": "Only used to get the latent shape"}),
|
||||
"window_time": ("INT", {"default": 11, "min": 1, "tooltip": "Temporal attention window size"}),
|
||||
"window_width": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
|
||||
"window_height": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
|
||||
"sparsity": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"torch_compile": ("BOOLEAN", {"default": True, "tooltip": "Most likely required for reasonable memory usage"})
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("MODEL", )
|
||||
FUNCTION = "patch"
|
||||
DESCRIPTION = "Experimental node for patching attention mode to use NABLA sparse attention for video models, currently only works with Kadinsky5"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, model, latent, window_time, window_width, window_height, sparsity, torch_compile):
|
||||
if flex_attention is None or BlockMask is None:
|
||||
raise RuntimeError("can't import flex_attention from torch.nn.attention, requires newer pytorch version")
|
||||
|
||||
model_clone = model.clone()
|
||||
samples = latent["samples"]
|
||||
|
||||
sparse_params = get_sparse_params(samples, window_time, window_height, window_width, sparsity)
|
||||
nabla_attention = NABLA_Attention(sparse_params)
|
||||
|
||||
def attention_override_nabla(func, *args, **kwargs):
|
||||
return nabla_attention(*args, **kwargs)
|
||||
|
||||
if torch_compile:
|
||||
attention_override_nabla = torch.compile(attention_override_nabla, mode="max-autotune-no-cudagraphs", dynamic=True)
|
||||
|
||||
# attention override
|
||||
model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_nabla
|
||||
|
||||
return model_clone,
|
||||
|
||||
|
||||
class NABLA_Attention():
|
||||
def __init__(self, sparse_params):
|
||||
self.sparse_params = sparse_params
|
||||
|
||||
def __call__(self, q, k, v, heads, **kwargs):
|
||||
if q.shape[-2] < 3000 or k.shape[-2] < 3000:
|
||||
return optimized_attention(q, k, v, heads, **kwargs)
|
||||
block_mask = self.nablaT_v2(q, k, self.sparse_params["sta_mask"], thr=self.sparse_params["P"])
|
||||
out = flex_attention(q, k, v, block_mask=block_mask).transpose(1, 2).contiguous().flatten(-2, -1)
|
||||
return out
|
||||
|
||||
def nablaT_v2(self, q, k, sta, thr=0.9):
|
||||
# Map estimation
|
||||
BLOCK_SIZE = 64
|
||||
B, h, S, D = q.shape
|
||||
s1 = S // BLOCK_SIZE
|
||||
qa = q.reshape(B, h, s1, BLOCK_SIZE, D).mean(-2)
|
||||
ka = k.reshape(B, h, s1, BLOCK_SIZE, D).mean(-2).transpose(-2, -1)
|
||||
map = qa @ ka
|
||||
|
||||
map = torch.softmax(map / math.sqrt(D), dim=-1)
|
||||
# Map binarization
|
||||
vals, inds = map.sort(-1)
|
||||
cvals = vals.cumsum_(-1)
|
||||
mask = (cvals >= 1 - thr).int()
|
||||
mask = mask.gather(-1, inds.argsort(-1))
|
||||
|
||||
mask = torch.logical_or(mask, sta)
|
||||
|
||||
# BlockMask creation
|
||||
kv_nb = mask.sum(-1).to(torch.int32)
|
||||
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
|
||||
return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=BLOCK_SIZE, mask_mod=None)
|
||||
|
||||
def fast_sta_nabla(T, H, W, wT=3, wH=3, wW=3):
|
||||
l = torch.Tensor([T, H, W]).amax()
|
||||
r = torch.arange(0, l, 1, dtype=torch.int16, device=mm.get_torch_device())
|
||||
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
|
||||
sta_t, sta_h, sta_w = (
|
||||
mat[:T, :T].flatten(),
|
||||
mat[:H, :H].flatten(),
|
||||
mat[:W, :W].flatten(),
|
||||
)
|
||||
sta_t = sta_t <= wT // 2
|
||||
sta_h = sta_h <= wH // 2
|
||||
sta_w = sta_w <= wW // 2
|
||||
sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
|
||||
sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
|
||||
return sta.reshape(T * H * W, T * H * W)
|
||||
|
||||
|
||||
def get_sparse_params(x, wT, wH, wW, sparsity=0.9):
|
||||
B, C, T, H, W = x.shape
|
||||
#print("x shape:", x.shape)
|
||||
patch_size = (1, 2, 2)
|
||||
T, H, W = (
|
||||
T // patch_size[0],
|
||||
H // patch_size[1],
|
||||
W // patch_size[2],
|
||||
)
|
||||
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW)
|
||||
sparse_params = {
|
||||
"sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
|
||||
"to_fractal": True,
|
||||
"P": sparsity,
|
||||
"wT": wT,
|
||||
"wH": wH,
|
||||
"wW": wW,
|
||||
"add_sta": True,
|
||||
"visual_shape": (T, H, W),
|
||||
"method": "topcdf",
|
||||
}
|
||||
|
||||
return sparse_params
|
||||
|
||||
from comfy.comfy_types.node_typing import IO
|
||||
class StartRecordCUDAMemoryHistory():
|
||||
# @classmethod
|
||||
# def IS_CHANGED(s):
|
||||
# return True
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"input": (IO.ANY,),
|
||||
"enabled": (["all", "state", "None"], {"default": "all", "tooltip": "None: disable, 'state': keep info for allocated memory, 'all': keep history of all alloc/free calls"}),
|
||||
"context": (["all", "state", "alloc", "None"], {"default": "all", "tooltip": "None: no tracebacks, 'state': tracebacks for allocated memory, 'alloc': for alloc calls, 'all': for free calls"}),
|
||||
"stacks": (["python", "all"], {"default": "all", "tooltip": "'python': Python/TorchScript/inductor frames, 'all': also C++ frames"}),
|
||||
"max_entries": ("INT", {"default": 100000, "min": 1000, "max": 10000000, "tooltip": "Maximum number of entries to record"}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY, )
|
||||
RETURN_NAMES = ("input", "output_path",)
|
||||
FUNCTION = "start"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
DESCRIPTION = "THIS NODE ALWAYS RUNS. Starts recording CUDA memory allocation history, can be ended and saved with EndRecordCUDAMemoryHistory. "
|
||||
|
||||
def start(self, input, enabled, context, stacks, max_entries):
|
||||
mm.soft_empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats(mm.get_torch_device())
|
||||
torch.cuda.memory._record_memory_history(
|
||||
max_entries=max_entries,
|
||||
enabled=enabled if enabled != "None" else None,
|
||||
context=context if context != "None" else None,
|
||||
stacks=stacks
|
||||
)
|
||||
return input,
|
||||
|
||||
class EndRecordCUDAMemoryHistory():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"input": (IO.ANY,),
|
||||
"output_path": ("STRING", {"default": "comfy_cuda_memory_history"}, "Base path for saving the CUDA memory history file, timestamp and .pt extension will be added"),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = (IO.ANY, "STRING",)
|
||||
RETURN_NAMES = ("input", "output_path",)
|
||||
FUNCTION = "end"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
DESCRIPTION = "Records CUDA memory allocation history between start and end, saves to a file that can be analyzed here: https://docs.pytorch.org/memory_viz or with VisualizeCUDAMemoryHistory node"
|
||||
|
||||
def end(self, input, output_path):
|
||||
mm.soft_empty_cache()
|
||||
time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = f"{output_path}{time}.pt"
|
||||
torch.cuda.memory._dump_snapshot(output_path)
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
return input, output_path
|
||||
|
||||
|
||||
try:
|
||||
from server import PromptServer
|
||||
except:
|
||||
PromptServer = None
|
||||
|
||||
class VisualizeCUDAMemoryHistory():
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"snapshot_path": ("STRING", ),
|
||||
},
|
||||
"hidden": {
|
||||
"unique_id": "UNIQUE_ID",
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("STRING",)
|
||||
RETURN_NAMES = ("output_path",)
|
||||
FUNCTION = "visualize"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
DESCRIPTION = "Visualizes a CUDA memory allocation history file, opens in browser"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
def visualize(self, snapshot_path, unique_id):
|
||||
import pickle
|
||||
from torch.cuda import _memory_viz
|
||||
import uuid
|
||||
|
||||
from folder_paths import get_output_directory
|
||||
output_dir = get_output_directory()
|
||||
|
||||
with open(snapshot_path, "rb") as f:
|
||||
snapshot = pickle.load(f)
|
||||
|
||||
html = _memory_viz.trace_plot(snapshot)
|
||||
html_filename = f"cuda_memory_history_{uuid.uuid4().hex}.html"
|
||||
output_path = os.path.join(output_dir, "memory_history", html_filename)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(html)
|
||||
|
||||
api_url = f"http://localhost:8188/api/view?type=output&filename={html_filename}&subfolder=memory_history"
|
||||
|
||||
# Progress UI
|
||||
if unique_id and PromptServer is not None:
|
||||
try:
|
||||
PromptServer.instance.send_progress_text(
|
||||
api_url,
|
||||
unique_id
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
return api_url,
|
||||
|
||||
347
nodes/nodes.py
347
nodes/nodes.py
@ -674,7 +674,7 @@ Converts any type to a string.
|
||||
elif isinstance(input, list):
|
||||
stringified = ', '.join(str(item) for item in input)
|
||||
else:
|
||||
return
|
||||
return input,
|
||||
if prefix: # Check if prefix is not empty
|
||||
stringified = prefix + stringified # Add the prefix
|
||||
if suffix: # Check if suffix is not empty
|
||||
@ -804,6 +804,7 @@ The choices are loaded from 'custom_dimensions.json' in the nodes folder.
|
||||
|
||||
return (latent, int(width), int(height),)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
class WidgetToString:
|
||||
@classmethod
|
||||
def IS_CHANGED(cls,*,id,node_title,any_input,**kwargs):
|
||||
@ -819,11 +820,11 @@ class WidgetToString:
|
||||
"return_all": ("BOOLEAN", {"default": False}),
|
||||
},
|
||||
"optional": {
|
||||
"any_input": (IO.ANY, ),
|
||||
"node_title": ("STRING", {"multiline": False}),
|
||||
"allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}),
|
||||
|
||||
},
|
||||
"any_input": (IO.ANY, ),
|
||||
"node_title": ("STRING", {"multiline": False}),
|
||||
"allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}),
|
||||
|
||||
},
|
||||
"hidden": {"extra_pnginfo": "EXTRA_PNGINFO",
|
||||
"prompt": "PROMPT",
|
||||
"unique_id": "UNIQUE_ID",},
|
||||
@ -842,54 +843,231 @@ The 'any_input' is required for making sure the node you want the value from exi
|
||||
"""
|
||||
|
||||
def get_widget_value(self, id, widget_name, extra_pnginfo, prompt, unique_id, return_all=False, any_input=None, node_title="", allowed_float_decimals=2):
|
||||
"""
|
||||
Retrieves the value of the specified widget from a node in the workflow and
|
||||
returns it as a string.
|
||||
|
||||
If no `id` or `node_title` is provided, the method attempts to identify the
|
||||
node using the `any_input` connection in the workflow. Enable node ID display
|
||||
in ComfyUI's "Manager" menu to view node IDs, or use a manually edited node
|
||||
title for searching. NOTE: A node does not have a title unless it is manually
|
||||
edited to something other than its default value.
|
||||
|
||||
Args:
|
||||
id (int): The unique ID of the target node. If 0, the method relies on
|
||||
other methods to determine the node. TODO: change to a STRING (breaking change)
|
||||
widget_name (str): The name of the widget whose value needs to be retrieved.
|
||||
extra_pnginfo (dict): A dictionary containing workflow metadata, including
|
||||
node connections and state.
|
||||
prompt (dict): A dictionary containing node-specific data with input
|
||||
settings to extract widget values.
|
||||
unique_id (str): The unique identifier of the current node instance, used
|
||||
to match the `any_input` connection.
|
||||
return_all (bool): If True, retrieves and returns all input values from
|
||||
the node, formatted as a string.
|
||||
any_input (str): Optional. A link reference used to determine the node if
|
||||
no `id` or `node_title` is provided.
|
||||
node_title (str): Optional. The title of the node to search for. Titles
|
||||
are valid only if manually assigned in ComfyUI.
|
||||
allowed_float_decimals (int): The number of decimal places to which float
|
||||
values should be rounded in the output.
|
||||
|
||||
Returns:
|
||||
str or tuple:
|
||||
- If `return_all` is False, returns a tuple with the value of the
|
||||
specified widget.
|
||||
- If `return_all` is True, returns a formatted string containing all
|
||||
input values for the node.
|
||||
|
||||
Raises:
|
||||
ValueError: If no matching node is found for the given `id`, `node_title`,
|
||||
or `any_input`.
|
||||
NameError: If the specified widget does not exist in the identified node.
|
||||
"""
|
||||
workflow = extra_pnginfo["workflow"]
|
||||
#print(json.dumps(workflow, indent=4))
|
||||
results = []
|
||||
node_id = None # Initialize node_id to handle cases where no match is found
|
||||
link_id = None
|
||||
target_full_node_id = None # string like "5", "5:1", "5:9:6"
|
||||
active_link_id = None
|
||||
|
||||
# Normalize incoming ids which may be lists/tuples (e.g., ["7:9:14", 0])
|
||||
def normalize_any_id(value):
|
||||
# If list/tuple, take the first element which should be the id/path
|
||||
if isinstance(value, (list, tuple)) and value:
|
||||
value = value[0]
|
||||
# Convert ints to str
|
||||
if isinstance(value, int):
|
||||
return str(value)
|
||||
# Pass through strings; None -> empty
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
id_str = normalize_any_id(id)
|
||||
unique_id_str = normalize_any_id(unique_id)
|
||||
|
||||
# Map of (scope_key, link_id) -> full_node_id
|
||||
# scope_key: '' for top-level, or the subgraph instance path for nested nodes (e.g., '5', '5:9')
|
||||
link_to_node_map = {}
|
||||
|
||||
for node in workflow["nodes"]:
|
||||
if node_title:
|
||||
if "title" in node:
|
||||
if node["title"] == node_title:
|
||||
node_id = node["id"]
|
||||
break
|
||||
else:
|
||||
print("Node title not found.")
|
||||
elif id != 0:
|
||||
if node["id"] == id:
|
||||
node_id = id
|
||||
break
|
||||
elif any_input is not None:
|
||||
if node["type"] == "WidgetToString" and node["id"] == int(unique_id) and not link_id:
|
||||
for node_input in node["inputs"]:
|
||||
if node_input["name"] == "any_input":
|
||||
link_id = node_input["link"]
|
||||
|
||||
# Construct a map of links to node IDs for future reference
|
||||
node_outputs = node.get("outputs", None)
|
||||
if not node_outputs:
|
||||
# Build a map of subgraph id -> definition for quick lookup
|
||||
defs = workflow.get("definitions", {}) or {}
|
||||
subgraph_defs = {sg.get("id"): sg for sg in (defs.get("subgraphs", []) or []) if sg.get("id")}
|
||||
|
||||
# Helper: register output links -> node map (scoped)
|
||||
def register_links(scope_key, node_obj, full_node_id):
|
||||
outputs = node_obj.get("outputs") or []
|
||||
for out in outputs:
|
||||
links = out.get("links")
|
||||
if not links:
|
||||
continue
|
||||
for output in node_outputs:
|
||||
node_links = output.get("links", None)
|
||||
if not node_links:
|
||||
continue
|
||||
for link in node_links:
|
||||
link_to_node_map[link] = node["id"]
|
||||
if link_id and link == link_id:
|
||||
break
|
||||
|
||||
if link_id:
|
||||
node_id = link_to_node_map.get(link_id, None)
|
||||
if isinstance(links, list):
|
||||
for lid in links:
|
||||
if lid is None:
|
||||
continue
|
||||
link_to_node_map[(scope_key, lid)] = full_node_id
|
||||
|
||||
if node_id is None:
|
||||
raise ValueError("No matching node found for the given title or id")
|
||||
# Recursive emitter for a subgraph instance
|
||||
# instance_path: the full path to this subgraph instance (e.g., '5' or '5:9')
|
||||
def emit_subgraph_instance(sub_def, instance_path):
|
||||
for snode in (sub_def.get("nodes") or []):
|
||||
child_id = str(snode.get("id"))
|
||||
full_id = f"{instance_path}:{child_id}"
|
||||
# Yield the node with the scope of this subgraph instance
|
||||
yield full_id, instance_path, snode
|
||||
# If this node itself is a subgraph instance, recurse
|
||||
stype = snode.get("type")
|
||||
nested_def = subgraph_defs.get(stype)
|
||||
if nested_def is not None:
|
||||
nested_instance_path = full_id # e.g., '5:9'
|
||||
for inner in emit_subgraph_instance(nested_def, nested_instance_path):
|
||||
yield inner
|
||||
|
||||
# Master iterator: yields all nodes with their full_node_id and scope
|
||||
def iter_all_nodes():
|
||||
# 1) Top-level nodes
|
||||
for node in workflow.get("nodes", []):
|
||||
full_node_id = str(node.get("id"))
|
||||
scope_key = "" # top-level link id space
|
||||
yield full_node_id, scope_key, node
|
||||
|
||||
# 2) If a top-level node is an instance of a subgraph, emit its internal nodes
|
||||
ntype = node.get("type")
|
||||
sg_def = subgraph_defs.get(ntype)
|
||||
if sg_def is not None:
|
||||
instance_path = full_node_id # e.g., '5'
|
||||
for item in emit_subgraph_instance(sg_def, instance_path):
|
||||
yield item
|
||||
|
||||
# Helpers for id/unique_id handling
|
||||
def match_id_with_fullness(candidate_full_id, requested_id):
|
||||
# Exact match if the request is fully qualified
|
||||
if ":" in requested_id:
|
||||
return candidate_full_id == requested_id
|
||||
# Otherwise, allow exact top-level id or ending with ":child"
|
||||
return candidate_full_id == requested_id or candidate_full_id.endswith(f":{requested_id}")
|
||||
|
||||
def parent_scope_of(full_id):
|
||||
parts = full_id.split(":")
|
||||
return ":".join(parts[:-1]) if len(parts) > 1 else ""
|
||||
|
||||
def resolve_scope_from_unique_id(u_str):
|
||||
# Fully qualified: everything before the last segment is the scope
|
||||
if ":" in u_str:
|
||||
return parent_scope_of(u_str)
|
||||
|
||||
# Not qualified: try to infer from prompt keys by suffix
|
||||
suffix = f":{u_str}"
|
||||
matches = [k for k in prompt.keys() if isinstance(k, str) and k.endswith(suffix)]
|
||||
matches = list(dict.fromkeys(matches)) # dedupe
|
||||
if len(matches) == 1:
|
||||
return parent_scope_of(matches[0])
|
||||
elif len(matches) == 0:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Ambiguous unique_id '{u_str}'. Multiple subgraph instances match. "
|
||||
f"Use a fully qualified id like 'parentPath:{u_str}' (e.g., '5:9:{u_str}')."
|
||||
)
|
||||
|
||||
# First: build a complete list of nodes and the scoped link map
|
||||
all_nodes = []
|
||||
for full_node_id, scope_key, node in iter_all_nodes():
|
||||
all_nodes.append((full_node_id, scope_key, node))
|
||||
register_links(scope_key, node, full_node_id)
|
||||
|
||||
# Try title or id first
|
||||
if node_title:
|
||||
for full_node_id, _, node in all_nodes:
|
||||
if "title" in node and node.get("title") == node_title:
|
||||
target_full_node_id = full_node_id
|
||||
break
|
||||
# If title matched, do not attempt any_input fallback
|
||||
any_input = None
|
||||
elif id_str not in ("", "0"):
|
||||
matches = [fid for fid, _, _ in all_nodes if match_id_with_fullness(fid, id_str)]
|
||||
if len(matches) > 1 and ":" not in id_str and any(m != id_str for m in matches):
|
||||
raise ValueError(
|
||||
f"Ambiguous id '{id_str}'. Multiple nodes match across (nested) subgraphs. "
|
||||
f"Use a fully qualified id like '5:9:{id_str}'."
|
||||
)
|
||||
target_full_node_id = matches[0] if matches else None
|
||||
|
||||
# Resolve via any_input + unique_id if still not found
|
||||
if target_full_node_id is None and any_input is not None and unique_id_str:
|
||||
# If unique_id is fully qualified, select that exact node
|
||||
wts_full_id = None
|
||||
if ":" in unique_id_str:
|
||||
for fid, _, node in all_nodes:
|
||||
if fid == unique_id_str and node.get("type") == "WidgetToString":
|
||||
wts_full_id = fid
|
||||
break
|
||||
if wts_full_id is None:
|
||||
raise ValueError(f"No WidgetToString found for unique_id '{unique_id_str}'")
|
||||
found_scope_key = parent_scope_of(wts_full_id)
|
||||
else:
|
||||
# Infer scope from prompt keys when unqualified
|
||||
found_scope_key = resolve_scope_from_unique_id(unique_id_str)
|
||||
candidates = []
|
||||
if found_scope_key:
|
||||
candidates.append(f"{found_scope_key}:{unique_id_str}")
|
||||
else:
|
||||
candidates.append(unique_id_str)
|
||||
|
||||
for fid, scope_key, node in all_nodes:
|
||||
if node.get("type") == "WidgetToString" and fid in candidates:
|
||||
wts_full_id = fid
|
||||
if not found_scope_key:
|
||||
found_scope_key = parent_scope_of(fid)
|
||||
break
|
||||
|
||||
if wts_full_id is None:
|
||||
raise ValueError(f"No WidgetToString found for unique_id '{unique_id_str}'")
|
||||
|
||||
# With the WidgetToString located, read its any_input link id
|
||||
wts_node = next(node for fid, _, node in all_nodes if fid == wts_full_id)
|
||||
for node_input in (wts_node.get("inputs") or []):
|
||||
if node_input.get("name") == "any_input":
|
||||
active_link_id = node_input.get("link")
|
||||
break
|
||||
|
||||
if active_link_id is None:
|
||||
raise ValueError(f"WidgetToString '{wts_full_id}' has no 'any_input' link")
|
||||
|
||||
# Resolve the producer of that link within the correct scope
|
||||
target_full_node_id = link_to_node_map.get((found_scope_key or "", active_link_id))
|
||||
if target_full_node_id is None:
|
||||
raise ValueError(
|
||||
f"Could not resolve link {active_link_id} in scope '{found_scope_key}'. "
|
||||
f"The subgraph clone’s links may not have been discovered."
|
||||
)
|
||||
|
||||
if target_full_node_id is None:
|
||||
raise ValueError("No matching node found for the given title, id, or any_input")
|
||||
|
||||
values = prompt.get(str(target_full_node_id))
|
||||
if not values:
|
||||
raise ValueError(f"No prompt entry found for node id: {target_full_node_id}")
|
||||
|
||||
values = prompt[str(node_id)]
|
||||
if "inputs" in values:
|
||||
if return_all:
|
||||
# Format items based on type
|
||||
formatted_items = []
|
||||
for k, v in values["inputs"].items():
|
||||
if isinstance(v, float):
|
||||
@ -906,7 +1084,7 @@ The 'any_input' is required for making sure the node you want the value from exi
|
||||
v = str(v)
|
||||
return (v, )
|
||||
else:
|
||||
raise NameError(f"Widget not found: {node_id}.{widget_name}")
|
||||
raise NameError(f"Widget not found: {target_full_node_id}.{widget_name}")
|
||||
return (', '.join(results).strip(', '), )
|
||||
|
||||
class DummyOut:
|
||||
@ -2623,3 +2801,80 @@ class LazySwitchKJ:
|
||||
def switch(self, switch, on_false = None, on_true=None):
|
||||
value = on_true if switch else on_false
|
||||
return (value,)
|
||||
|
||||
|
||||
from comfy.patcher_extension import WrappersMP
|
||||
from comfy.sampler_helpers import prepare_mask
|
||||
class TTM_SampleWrapper:
|
||||
def __init__(self, mask, steps):
|
||||
self.mask = mask
|
||||
self.steps = steps
|
||||
|
||||
def __call__(self, sampler, guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar):
|
||||
model_options = extra_args["model_options"]
|
||||
wrappers = model_options["transformer_options"]["wrappers"]
|
||||
w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {})
|
||||
|
||||
if self.mask is not None:
|
||||
motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1]))
|
||||
motion_mask = prepare_mask(motion_mask, noise.shape, noise.device)
|
||||
|
||||
scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint
|
||||
w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.steps, scale_latent_inpaint)]
|
||||
|
||||
out = sampler(guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class TTM_ApplyModel_Wrapper:
|
||||
def __init__(self, reference_samples, noise, motion_mask, steps, scale_latent_inpaint):
|
||||
self.reference_samples = reference_samples
|
||||
self.noise = noise
|
||||
self.motion_mask = motion_mask
|
||||
self.steps = steps
|
||||
self.scale_latent_inpaint = scale_latent_inpaint
|
||||
|
||||
def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
|
||||
sigmas = transformer_options["sample_sigmas"]
|
||||
|
||||
matched = (sigmas == t).nonzero(as_tuple=True)[0]
|
||||
if matched.numel() > 0:
|
||||
current_step_index = matched.item()
|
||||
else:
|
||||
crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0]
|
||||
current_step_index = crossing.item() if crossing.numel() > 0 else 0
|
||||
|
||||
next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index]
|
||||
|
||||
if current_step_index != 0 and current_step_index < self.steps:
|
||||
noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x))
|
||||
if self.motion_mask is not None:
|
||||
x = x * (1-self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x)
|
||||
else:
|
||||
x = noisy_latent
|
||||
|
||||
return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
|
||||
class LatentInpaintTTM:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"model": ("MODEL", ),
|
||||
"steps": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1, "tooltip": "Number of steps to apply TTM inpainting for."}),
|
||||
},
|
||||
"optional": {
|
||||
"mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "patch"
|
||||
EXPERIMENTAL = True
|
||||
DESCRIPTION = "https://github.com/time-to-move/TTM"
|
||||
CATEGORY = "KJNodes/experimental"
|
||||
|
||||
def patch(self, model, steps, mask=None):
|
||||
m = model.clone()
|
||||
m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, steps))
|
||||
return (m, )
|
||||
Loading…
x
Reference in New Issue
Block a user