Compare commits

...

14 Commits

Author SHA1 Message Date
sfinktah
de6031f606
Merge 363c90f4e3b723bdd1a20e105f1aae0cff1043ed into 50e7dd34d3b6e6bbab1d41e8068e1ddd19bd4d1b 2025-12-04 00:46:57 -06:00
kijai
50e7dd34d3 Update model_optimization_nodes.py 2025-12-03 17:18:13 +02:00
kijai
37206374ef Add nodes to assist with CUDA memory use visualization 2025-12-03 17:13:44 +02:00
Jukka Seppänen
06a60ac3fe
Merge pull request #450 from m-sokes/patch-1
Update image_nodes.py
2025-12-01 11:13:51 +02:00
Sokes
a7ce03e735
Update image_nodes.py
sorted() is needed around os.listdir() for proper Linux file sorting.

The reason your files are not in order is that os.listdir() returns filenames in an arbitrary order (usually based on how they are stored in the file system's inode table), not alphabetically or numerically.
On Windows, os.listdir sometimes appears sorted due to how NTFS works, but on Ubuntu (Linux), the raw directory listing is almost never sorted by name.
The Fix
You need to sort the list of files before iterating through them.
Change this line:

    for f in os.listdir(kwargs['video']):

To this:

    for f in sorted(os.listdir(kwargs['video'])):
2025-11-30 18:35:29 -05:00
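A caveat on the sorted() fix above: sorted() orders names lexicographically, so frame_10.mp4 sorts before frame_2.mp4. When filenames carry numeric counters, a natural-sort key is the usual remedy; a minimal sketch (the helper and folder name below are illustrative, not part of the patch):

    import os
    import re

    def natural_key(name):
        # Split into digit and non-digit runs so numeric parts compare as
        # integers: "frame_2" sorts before "frame_10" instead of after it.
        return [int(part) if part.isdigit() else part.lower()
                for part in re.split(r"(\d+)", name)]

    files = sorted(os.listdir("some_folder"), key=natural_key)  # "some_folder" is a placeholder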
kijai
f37df472df Kandinsky5 blocks for compile too 2025-11-27 17:57:52 +02:00
kijai
390d05fe7e Add generic TorchCompileModelAdvanced node to handle advanced compile options for all diffusion models
Avoids needing different nodes for different models
2025-11-27 13:59:31 +02:00
kijai
f0ed965cd9 Allow fp32 input for sageattn function 2025-11-27 13:33:41 +02:00
kijai
acdd16a973 Add NABLA_AttentionKJ
Only tested with Kandinsky5
2025-11-26 23:40:12 +02:00
kijai
4dfb85dcc5 Update nodes.py 2025-11-23 02:18:47 +02:00
kijai
8660778ea1 TTM: Rename end_steps to steps for clarity 2025-11-23 02:10:23 +02:00
kijai
3b9c1b49ab Add LatentInpaintTTM
Can be used to mimic:
https://github.com/time-to-move/TTM
2025-11-23 01:47:56 +02:00
kijai
246920d8b9 Update nodes.py 2025-11-22 17:46:28 +02:00
Christopher Anderson
363c90f4e3 Added subgraph support for Widget2Str 2025-09-24 17:50:24 +10:00
4 changed files with 613 additions and 54 deletions

View File

@@ -209,6 +209,12 @@ NODE_CONFIG = {
"ModelPatchTorchSettings": {"class": ModelPatchTorchSettings, "name": "Model Patch Torch Settings"},
"WanVideoNAG": {"class": WanVideoNAG, "name": "WanVideoNAG"},
"GGUFLoaderKJ": {"class": GGUFLoaderKJ, "name": "GGUF Loader KJ"},
"LatentInpaintTTM": {"class": LatentInpaintTTM, "name": "Latent Inpaint TTM"},
"NABLA_AttentionKJ": {"class": NABLA_AttentionKJ, "name": "NABLA Attention KJ"},
"TorchCompileModelAdvanced": {"class": TorchCompileModelAdvanced, "name": "TorchCompileModelAdvanced"},
"StartRecordCUDAMemoryHistory": {"class": StartRecordCUDAMemoryHistory, "name": "Start Recording CUDAMemory History"},
"EndRecordCUDAMemoryHistory": {"class": EndRecordCUDAMemoryHistory, "name": "End Recording CUDAMemory History"},
"VisualizeCUDAMemoryHistory": {"class": VisualizeCUDAMemoryHistory, "name": "Visualize CUDAMemory History"},
#instance diffusion
"CreateInstanceDiffusionTracking": {"class": CreateInstanceDiffusionTracking},

View File

@@ -176,7 +176,7 @@ Saves an image and mask as .PNG with the mask as the alpha channel.
def file_counter():
max_counter = 0
# Loop through the existing files
- for existing_file in os.listdir(full_output_folder):
+ for existing_file in sorted(os.listdir(full_output_folder)):
# Check if the file matches the expected format
match = re.fullmatch(fr"{filename}_(\d+)_?\.[a-zA-Z0-9]+", existing_file)
if match:
@@ -2981,7 +2981,7 @@ class LoadImagesFromFolderKJ:
except OSError:
pass
else:
- for file in os.listdir(folder):
+ for file in sorted(os.listdir(folder)):
if any(file.lower().endswith(ext) for ext in valid_extensions):
path = os.path.join(folder, file)
try:
@@ -3043,7 +3043,7 @@ class LoadImagesFromFolderKJ:
if any(file.lower().endswith(ext) for ext in valid_extensions):
image_paths.append(os.path.join(root, file))
else:
- for file in os.listdir(folder):
+ for file in sorted(os.listdir(folder)):
if any(file.lower().endswith(ext) for ext in valid_extensions):
image_paths.append(os.path.join(folder, file))
@@ -3964,7 +3964,7 @@ class LoadVideosFromFolder:
raise ImportError("This node requires ComfyUI-VideoHelperSuite to be installed.")
videos_list = []
filenames = []
- for f in os.listdir(kwargs['video']):
+ for f in sorted(os.listdir(kwargs['video'])):
if os.path.isfile(os.path.join(kwargs['video'], f)):
file_parts = f.split('.')
if len(file_parts) > 1 and (file_parts[-1].lower() in ['webm', 'mp4', 'mkv', 'gif', 'mov']):

View File

@@ -3,15 +3,18 @@ from comfy.ldm.modules import attention as comfy_attention
import logging
import torch
import importlib
import math
import datetime
import folder_paths
import comfy.model_management as mm
from comfy.cli_args import args
- from comfy.ldm.modules.attention import wrap_attn
+ from comfy.ldm.modules.attention import wrap_attn, optimized_attention
import comfy.model_patcher
import comfy.utils
import comfy.sd
try:
from comfy_api.latest import io
v3_available = True
@@ -71,6 +74,9 @@ def get_sage_func(sage_attention, allow_compile=False):
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
in_dtype = v.dtype
if q.dtype == torch.float32 or k.dtype == torch.float32 or v.dtype == torch.float32:
q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout="HND"
@@ -89,7 +95,7 @@ def get_sage_func(sage_attention, allow_compile=False):
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
- out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+ out = sage_func(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout).to(in_dtype)
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
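The two changes in this hunk cast fp32 q/k/v down to fp16 before calling the SageAttention kernel and cast the output back to the original value dtype. A condensed sketch of that round-trip, independent of the wrapper (it assumes sageattn only accepts half-precision inputs, which is what this commit works around):

    import torch

    def sage_roundtrip(sage_func, q, k, v, **kwargs):
        in_dtype = v.dtype  # remember the incoming dtype for downstream layers
        if in_dtype == torch.float32:
            # SageAttention kernels reject fp32; drop to fp16 for the call
            q, k, v = q.to(torch.float16), k.to(torch.float16), v.to(torch.float16)
        return sage_func(q, k, v, **kwargs).to(in_dtype)  # restore original dtype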
@@ -675,6 +681,7 @@ class TorchCompileModelFluxAdvancedV2:
try:
if double_blocks:
for i, block in enumerate(diffusion_model.double_blocks):
print("Adding double block to compile list", i)
compile_key_list.append(f"diffusion_model.double_blocks.{i}")
if single_blocks:
for i, block in enumerate(diffusion_model.single_blocks):
@@ -718,7 +725,7 @@ class TorchCompileModelHyVideo:
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
DEPRECATED = True
CATEGORY = "KJNodes/torchcompile"
EXPERIMENTAL = True
@@ -850,7 +857,60 @@ class TorchCompileModelWanVideoV2:
raise RuntimeError("Failed to compile model")
return (m, )
class TorchCompileModelAdvanced:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"backend": (["inductor","cudagraphs"], {"default": "inductor"}),
"fullgraph": ("BOOLEAN", {"default": False, "tooltip": "Enable full graph mode"}),
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"compile_transformer_blocks_only": ("BOOLEAN", {"default": True, "tooltip": "Compile only transformer blocks, faster compile and less error prone"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
"debug_compile_keys": ("BOOLEAN", {"default": False, "tooltip": "Print the compile keys used for torch.compile"}),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "KJNodes/torchcompile"
DESCRIPTION = "Advanced torch.compile patching for diffusion models."
EXPERIMENTAL = True
def patch(self, model, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_transformer_blocks_only, debug_compile_keys):
from comfy_api.torch_helpers import set_torch_compile_wrapper
m = model.clone()
diffusion_model = m.get_model_object("diffusion_model")
torch._dynamo.config.cache_size_limit = dynamo_cache_size_limit
try:
if compile_transformer_blocks_only:
layer_types = ["double_blocks", "single_blocks", "layers", "transformer_blocks", "blocks", "visual_transformer_blocks", "text_transformer_blocks"]
compile_key_list = []
for layer_name in layer_types:
if hasattr(diffusion_model, layer_name):
blocks = getattr(diffusion_model, layer_name)
for i in range(len(blocks)):
compile_key_list.append(f"diffusion_model.{layer_name}.{i}")
if not compile_key_list:
logging.warning("No known transformer blocks found to compile, compiling entire diffusion model instead")
elif debug_compile_keys:
logging.info("TorchCompileModelAdvanced: Compile key list:")
for key in compile_key_list:
logging.info(f" - {key}")
if not compile_key_list:
compile_key_list = ["diffusion_model"]
set_torch_compile_wrapper(model=m, keys=compile_key_list, backend=backend, mode=mode, dynamic=dynamic, fullgraph=fullgraph)
except Exception as e:
raise RuntimeError("Failed to compile model") from e
return (m, )
class TorchCompileModelQwenImage:
@classmethod
def INPUT_TYPES(s):
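The TorchCompileModelAdvanced node above implements the commit's "generic" approach: with compile_transformer_blocks_only set, it enumerates whichever block list the model exposes (double_blocks, transformer_blocks, blocks, ...) and compiles each block separately instead of the whole model. A minimal standalone sketch of the same idea using plain torch.compile (set_torch_compile_wrapper is ComfyUI's helper; the attribute name here is an assumption):

    import torch
    import torch.nn as nn

    def compile_blocks_only(model: nn.Module, attr: str = "blocks") -> nn.Module:
        # Compiling repeated blocks one by one keeps each graph small, avoids
        # recompiling the whole model on a graph break, and surfaces errors
        # closer to their source than a single whole-model compile.
        blocks = getattr(model, attr)
        for i in range(len(blocks)):
            blocks[i] = torch.compile(blocks[i], backend="inductor", mode="default")
        return model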
@@ -2005,3 +2065,241 @@ else:
FUNCTION = ""
CATEGORY = ""
DESCRIPTION = "This node requires newer ComfyUI"
try:
from torch.nn.attention.flex_attention import flex_attention, BlockMask
except ImportError:
flex_attention = None
BlockMask = None
class NABLA_AttentionKJ():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL",),
"latent": ("LATENT", {"tooltip": "Only used to get the latent shape"}),
"window_time": ("INT", {"default": 11, "min": 1, "tooltip": "Temporal attention window size"}),
"window_width": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
"window_height": ("INT", {"default": 3, "min": 1, "tooltip": "Spatial attention window size"}),
"sparsity": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.01}),
"torch_compile": ("BOOLEAN", {"default": True, "tooltip": "Most likely required for reasonable memory usage"})
},
}
RETURN_TYPES = ("MODEL", )
FUNCTION = "patch"
DESCRIPTION = "Experimental node for patching attention mode to use NABLA sparse attention for video models, currently only works with Kadinsky5"
CATEGORY = "KJNodes/experimental"
def patch(self, model, latent, window_time, window_width, window_height, sparsity, torch_compile):
if flex_attention is None or BlockMask is None:
raise RuntimeError("can't import flex_attention from torch.nn.attention, requires newer pytorch version")
model_clone = model.clone()
samples = latent["samples"]
sparse_params = get_sparse_params(samples, window_time, window_height, window_width, sparsity)
nabla_attention = NABLA_Attention(sparse_params)
def attention_override_nabla(func, *args, **kwargs):
return nabla_attention(*args, **kwargs)
if torch_compile:
attention_override_nabla = torch.compile(attention_override_nabla, mode="max-autotune-no-cudagraphs", dynamic=True)
# attention override
model_clone.model_options["transformer_options"]["optimized_attention_override"] = attention_override_nabla
return model_clone,
class NABLA_Attention():
def __init__(self, sparse_params):
self.sparse_params = sparse_params
def __call__(self, q, k, v, heads, **kwargs):
if q.shape[-2] < 3000 or k.shape[-2] < 3000:
return optimized_attention(q, k, v, heads, **kwargs)
block_mask = self.nablaT_v2(q, k, self.sparse_params["sta_mask"], thr=self.sparse_params["P"])
out = flex_attention(q, k, v, block_mask=block_mask).transpose(1, 2).contiguous().flatten(-2, -1)
return out
def nablaT_v2(self, q, k, sta, thr=0.9):
# Map estimation
BLOCK_SIZE = 64
B, h, S, D = q.shape
s1 = S // BLOCK_SIZE
qa = q.reshape(B, h, s1, BLOCK_SIZE, D).mean(-2)
ka = k.reshape(B, h, s1, BLOCK_SIZE, D).mean(-2).transpose(-2, -1)
map = qa @ ka
map = torch.softmax(map / math.sqrt(D), dim=-1)
# Map binarization
vals, inds = map.sort(-1)
cvals = vals.cumsum_(-1)
mask = (cvals >= 1 - thr).int()
mask = mask.gather(-1, inds.argsort(-1))
mask = torch.logical_or(mask, sta)
# BlockMask creation
kv_nb = mask.sum(-1).to(torch.int32)
kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=BLOCK_SIZE, mask_mod=None)
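# The binarization above keeps, per query block, the highest-mass key blocks
# whose softmax mass sums to at least thr, then scatters the keep-decision back
# to the original block order. Toy single-row illustration (made-up numbers):
#
#   row = torch.tensor([0.05, 0.50, 0.10, 0.35])  # mass per key/value block
#   vals, inds = row.sort(-1)                     # ascending: 0.05, 0.10, 0.35, 0.50
#   keep = (vals.cumsum(-1) >= 1 - 0.9).int()     # thr=0.9 -> drop only the lowest ~10% of mass
#   mask = keep.gather(-1, inds.argsort(-1))      # original order: [0, 1, 1, 1], block 0 dropped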
def fast_sta_nabla(T, H, W, wT=3, wH=3, wW=3):
l = torch.Tensor([T, H, W]).amax()
r = torch.arange(0, l, 1, dtype=torch.int16, device=mm.get_torch_device())
mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
sta_t, sta_h, sta_w = (
mat[:T, :T].flatten(),
mat[:H, :H].flatten(),
mat[:W, :W].flatten(),
)
sta_t = sta_t <= wT // 2
sta_h = sta_h <= wH // 2
sta_w = sta_w <= wW // 2
sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten()
sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2)
return sta.reshape(T * H * W, T * H * W)
def get_sparse_params(x, wT, wH, wW, sparsity=0.9):
B, C, T, H, W = x.shape
#print("x shape:", x.shape)
patch_size = (1, 2, 2)
T, H, W = (
T // patch_size[0],
H // patch_size[1],
W // patch_size[2],
)
sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW)
sparse_params = {
"sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0),
"to_fractal": True,
"P": sparsity,
"wT": wT,
"wH": wH,
"wW": wW,
"add_sta": True,
"visual_shape": (T, H, W),
"method": "topcdf",
}
return sparse_params
from comfy.comfy_types.node_typing import IO
class StartRecordCUDAMemoryHistory():
# @classmethod
# def IS_CHANGED(s):
# return True
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"input": (IO.ANY,),
"enabled": (["all", "state", "None"], {"default": "all", "tooltip": "None: disable, 'state': keep info for allocated memory, 'all': keep history of all alloc/free calls"}),
"context": (["all", "state", "alloc", "None"], {"default": "all", "tooltip": "None: no tracebacks, 'state': tracebacks for allocated memory, 'alloc': for alloc calls, 'all': for free calls"}),
"stacks": (["python", "all"], {"default": "all", "tooltip": "'python': Python/TorchScript/inductor frames, 'all': also C++ frames"}),
"max_entries": ("INT", {"default": 100000, "min": 1000, "max": 10000000, "tooltip": "Maximum number of entries to record"}),
},
}
RETURN_TYPES = (IO.ANY, )
RETURN_NAMES = ("input", "output_path",)
FUNCTION = "start"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = "THIS NODE ALWAYS RUNS. Starts recording CUDA memory allocation history, can be ended and saved with EndRecordCUDAMemoryHistory. "
def start(self, input, enabled, context, stacks, max_entries):
mm.soft_empty_cache()
torch.cuda.reset_peak_memory_stats(mm.get_torch_device())
torch.cuda.memory._record_memory_history(
max_entries=max_entries,
enabled=enabled if enabled != "None" else None,
context=context if context != "None" else None,
stacks=stacks
)
return input,
class EndRecordCUDAMemoryHistory():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"input": (IO.ANY,),
"output_path": ("STRING", {"default": "comfy_cuda_memory_history"}, "Base path for saving the CUDA memory history file, timestamp and .pt extension will be added"),
},
}
RETURN_TYPES = (IO.ANY, "STRING",)
RETURN_NAMES = ("input", "output_path",)
FUNCTION = "end"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = "Records CUDA memory allocation history between start and end, saves to a file that can be analyzed here: https://docs.pytorch.org/memory_viz or with VisualizeCUDAMemoryHistory node"
def end(self, input, output_path):
mm.soft_empty_cache()
time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = f"{output_path}{time}.pt"
torch.cuda.memory._dump_snapshot(output_path)
torch.cuda.memory._record_memory_history(enabled=None)
return input, output_path
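# For reference, these two nodes wrap PyTorch's CUDA memory-history API
# (private, so semantics may change between releases); a standalone
# equivalent outside ComfyUI:
#
#   import torch
#   torch.cuda.memory._record_memory_history(max_entries=100_000)  # start (enabled="all" by default)
#   x = torch.randn(4096, 4096, device="cuda"); y = x @ x          # workload to capture
#   torch.cuda.memory._dump_snapshot("snapshot.pt")                # open at https://docs.pytorch.org/memory_viz
#   torch.cuda.memory._record_memory_history(enabled=None)         # stop recording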
try:
from server import PromptServer
except ImportError:
PromptServer = None
class VisualizeCUDAMemoryHistory():
@classmethod
def INPUT_TYPES(s):
return {"required": {
"snapshot_path": ("STRING", ),
},
"hidden": {
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("output_path",)
FUNCTION = "visualize"
CATEGORY = "KJNodes/experimental"
DESCRIPTION = "Visualizes a CUDA memory allocation history file, opens in browser"
OUTPUT_NODE = True
def visualize(self, snapshot_path, unique_id):
import pickle
from torch.cuda import _memory_viz
import uuid
from folder_paths import get_output_directory
output_dir = get_output_directory()
with open(snapshot_path, "rb") as f:
snapshot = pickle.load(f)
html = _memory_viz.trace_plot(snapshot)
html_filename = f"cuda_memory_history_{uuid.uuid4().hex}.html"
output_path = os.path.join(output_dir, "memory_history", html_filename)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
f.write(html)
api_url = f"http://localhost:8188/api/view?type=output&filename={html_filename}&subfolder=memory_history"
# Progress UI
if unique_id and PromptServer is not None:
try:
PromptServer.instance.send_progress_text(
api_url,
unique_id
)
except:
pass
return api_url,
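The HTML generation above is a thin wrapper over torch.cuda._memory_viz (a private module, so its interface may shift between releases); a standalone equivalent, minus the ComfyUI server plumbing:

    import pickle
    from torch.cuda import _memory_viz

    with open("snapshot.pt", "rb") as f:           # file written by _dump_snapshot
        snapshot = pickle.load(f)
    html = _memory_viz.trace_plot(snapshot)        # returns a self-contained HTML page
    with open("cuda_memory_trace.html", "w", encoding="utf-8") as f:
        f.write(html)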

View File

@@ -674,7 +674,7 @@ Converts any type to a string.
elif isinstance(input, list):
stringified = ', '.join(str(item) for item in input)
else:
- return
+ return input,
if prefix: # Check if prefix is not empty
stringified = prefix + stringified # Add the prefix
if suffix: # Check if suffix is not empty
@@ -804,6 +804,7 @@ The choices are loaded from 'custom_dimensions.json' in the nodes folder.
return (latent, int(width), int(height),)
# noinspection PyShadowingNames
class WidgetToString:
@classmethod
def IS_CHANGED(cls,*,id,node_title,any_input,**kwargs):
@@ -819,11 +820,11 @@ class WidgetToString:
"return_all": ("BOOLEAN", {"default": False}),
},
"optional": {
"any_input": (IO.ANY, ),
"node_title": ("STRING", {"multiline": False}),
"allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}),
},
"any_input": (IO.ANY, ),
"node_title": ("STRING", {"multiline": False}),
"allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}),
},
"hidden": {"extra_pnginfo": "EXTRA_PNGINFO",
"prompt": "PROMPT",
"unique_id": "UNIQUE_ID",},
@@ -842,54 +843,231 @@ The 'any_input' is required for making sure the node you want the value from exi
"""
def get_widget_value(self, id, widget_name, extra_pnginfo, prompt, unique_id, return_all=False, any_input=None, node_title="", allowed_float_decimals=2):
"""
Retrieves the value of the specified widget from a node in the workflow and
returns it as a string.
If no `id` or `node_title` is provided, the method attempts to identify the
node using the `any_input` connection in the workflow. Enable node ID display
in ComfyUI's "Manager" menu to view node IDs, or use a manually edited node
title for searching. NOTE: A node does not have a title unless it is manually
edited to something other than its default value.
Args:
id (int): The unique ID of the target node. If 0, the method relies on
other methods to determine the node. TODO: change to a STRING (breaking change)
widget_name (str): The name of the widget whose value needs to be retrieved.
extra_pnginfo (dict): A dictionary containing workflow metadata, including
node connections and state.
prompt (dict): A dictionary containing node-specific data with input
settings to extract widget values.
unique_id (str): The unique identifier of the current node instance, used
to match the `any_input` connection.
return_all (bool): If True, retrieves and returns all input values from
the node, formatted as a string.
any_input (str): Optional. A link reference used to determine the node if
no `id` or `node_title` is provided.
node_title (str): Optional. The title of the node to search for. Titles
are valid only if manually assigned in ComfyUI.
allowed_float_decimals (int): The number of decimal places to which float
values should be rounded in the output.
Returns:
str or tuple:
- If `return_all` is False, returns a tuple with the value of the
specified widget.
- If `return_all` is True, returns a formatted string containing all
input values for the node.
Raises:
ValueError: If no matching node is found for the given `id`, `node_title`,
or `any_input`.
NameError: If the specified widget does not exist in the identified node.
"""
workflow = extra_pnginfo["workflow"]
#print(json.dumps(workflow, indent=4))
results = []
- node_id = None # Initialize node_id to handle cases where no match is found
- link_id = None
target_full_node_id = None # string like "5", "5:1", "5:9:6"
active_link_id = None
# Normalize incoming ids which may be lists/tuples (e.g., ["7:9:14", 0])
def normalize_any_id(value):
# If list/tuple, take the first element which should be the id/path
if isinstance(value, (list, tuple)) and value:
value = value[0]
# Convert ints to str
if isinstance(value, int):
return str(value)
# Pass through strings; None -> empty
return value if isinstance(value, str) else ""
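# Example (hypothetical inputs): normalize_any_id(["7:9:14", 0]) -> "7:9:14",
# normalize_any_id(5) -> "5", normalize_any_id(None) -> "".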
id_str = normalize_any_id(id)
unique_id_str = normalize_any_id(unique_id)
# Map of (scope_key, link_id) -> full_node_id
# scope_key: '' for top-level, or the subgraph instance path for nested nodes (e.g., '5', '5:9')
link_to_node_map = {}
- for node in workflow["nodes"]:
- if node_title:
- if "title" in node:
- if node["title"] == node_title:
- node_id = node["id"]
- break
- else:
- print("Node title not found.")
- elif id != 0:
- if node["id"] == id:
- node_id = id
- break
- elif any_input is not None:
- if node["type"] == "WidgetToString" and node["id"] == int(unique_id) and not link_id:
- for node_input in node["inputs"]:
- if node_input["name"] == "any_input":
- link_id = node_input["link"]
- # Construct a map of links to node IDs for future reference
- node_outputs = node.get("outputs", None)
- if not node_outputs:
# Build a map of subgraph id -> definition for quick lookup
defs = workflow.get("definitions", {}) or {}
subgraph_defs = {sg.get("id"): sg for sg in (defs.get("subgraphs", []) or []) if sg.get("id")}
# Helper: register output links -> node map (scoped)
def register_links(scope_key, node_obj, full_node_id):
outputs = node_obj.get("outputs") or []
for out in outputs:
links = out.get("links")
if not links:
continue
- for output in node_outputs:
- node_links = output.get("links", None)
- if not node_links:
- continue
- for link in node_links:
- link_to_node_map[link] = node["id"]
- if link_id and link == link_id:
- break
- if link_id:
- node_id = link_to_node_map.get(link_id, None)
if isinstance(links, list):
for lid in links:
if lid is None:
continue
link_to_node_map[(scope_key, lid)] = full_node_id
- if node_id is None:
- raise ValueError("No matching node found for the given title or id")
# Recursive emitter for a subgraph instance
# instance_path: the full path to this subgraph instance (e.g., '5' or '5:9')
def emit_subgraph_instance(sub_def, instance_path):
for snode in (sub_def.get("nodes") or []):
child_id = str(snode.get("id"))
full_id = f"{instance_path}:{child_id}"
# Yield the node with the scope of this subgraph instance
yield full_id, instance_path, snode
# If this node itself is a subgraph instance, recurse
stype = snode.get("type")
nested_def = subgraph_defs.get(stype)
if nested_def is not None:
nested_instance_path = full_id # e.g., '5:9'
for inner in emit_subgraph_instance(nested_def, nested_instance_path):
yield inner
# Master iterator: yields all nodes with their full_node_id and scope
def iter_all_nodes():
# 1) Top-level nodes
for node in workflow.get("nodes", []):
full_node_id = str(node.get("id"))
scope_key = "" # top-level link id space
yield full_node_id, scope_key, node
# 2) If a top-level node is an instance of a subgraph, emit its internal nodes
ntype = node.get("type")
sg_def = subgraph_defs.get(ntype)
if sg_def is not None:
instance_path = full_node_id # e.g., '5'
for item in emit_subgraph_instance(sg_def, instance_path):
yield item
# Helpers for id/unique_id handling
def match_id_with_fullness(candidate_full_id, requested_id):
# Exact match if the request is fully qualified
if ":" in requested_id:
return candidate_full_id == requested_id
# Otherwise, allow exact top-level id or ending with ":child"
return candidate_full_id == requested_id or candidate_full_id.endswith(f":{requested_id}")
def parent_scope_of(full_id):
parts = full_id.split(":")
return ":".join(parts[:-1]) if len(parts) > 1 else ""
def resolve_scope_from_unique_id(u_str):
# Fully qualified: everything before the last segment is the scope
if ":" in u_str:
return parent_scope_of(u_str)
# Not qualified: try to infer from prompt keys by suffix
suffix = f":{u_str}"
matches = [k for k in prompt.keys() if isinstance(k, str) and k.endswith(suffix)]
matches = list(dict.fromkeys(matches)) # dedupe
if len(matches) == 1:
return parent_scope_of(matches[0])
elif len(matches) == 0:
return None
else:
raise ValueError(
f"Ambiguous unique_id '{u_str}'. Multiple subgraph instances match. "
f"Use a fully qualified id like 'parentPath:{u_str}' (e.g., '5:9:{u_str}')."
)
# First: build a complete list of nodes and the scoped link map
all_nodes = []
for full_node_id, scope_key, node in iter_all_nodes():
all_nodes.append((full_node_id, scope_key, node))
register_links(scope_key, node, full_node_id)
# Try title or id first
if node_title:
for full_node_id, _, node in all_nodes:
if "title" in node and node.get("title") == node_title:
target_full_node_id = full_node_id
break
# If title matched, do not attempt any_input fallback
any_input = None
elif id_str not in ("", "0"):
matches = [fid for fid, _, _ in all_nodes if match_id_with_fullness(fid, id_str)]
if len(matches) > 1 and ":" not in id_str and any(m != id_str for m in matches):
raise ValueError(
f"Ambiguous id '{id_str}'. Multiple nodes match across (nested) subgraphs. "
f"Use a fully qualified id like '5:9:{id_str}'."
)
target_full_node_id = matches[0] if matches else None
# Resolve via any_input + unique_id if still not found
if target_full_node_id is None and any_input is not None and unique_id_str:
# If unique_id is fully qualified, select that exact node
wts_full_id = None
if ":" in unique_id_str:
for fid, _, node in all_nodes:
if fid == unique_id_str and node.get("type") == "WidgetToString":
wts_full_id = fid
break
if wts_full_id is None:
raise ValueError(f"No WidgetToString found for unique_id '{unique_id_str}'")
found_scope_key = parent_scope_of(wts_full_id)
else:
# Infer scope from prompt keys when unqualified
found_scope_key = resolve_scope_from_unique_id(unique_id_str)
candidates = []
if found_scope_key:
candidates.append(f"{found_scope_key}:{unique_id_str}")
else:
candidates.append(unique_id_str)
for fid, scope_key, node in all_nodes:
if node.get("type") == "WidgetToString" and fid in candidates:
wts_full_id = fid
if not found_scope_key:
found_scope_key = parent_scope_of(fid)
break
if wts_full_id is None:
raise ValueError(f"No WidgetToString found for unique_id '{unique_id_str}'")
# With the WidgetToString located, read its any_input link id
wts_node = next(node for fid, _, node in all_nodes if fid == wts_full_id)
for node_input in (wts_node.get("inputs") or []):
if node_input.get("name") == "any_input":
active_link_id = node_input.get("link")
break
if active_link_id is None:
raise ValueError(f"WidgetToString '{wts_full_id}' has no 'any_input' link")
# Resolve the producer of that link within the correct scope
target_full_node_id = link_to_node_map.get((found_scope_key or "", active_link_id))
if target_full_node_id is None:
raise ValueError(
f"Could not resolve link {active_link_id} in scope '{found_scope_key}'. "
f"The subgraph clones links may not have been discovered."
)
if target_full_node_id is None:
raise ValueError("No matching node found for the given title, id, or any_input")
values = prompt.get(str(target_full_node_id))
if not values:
raise ValueError(f"No prompt entry found for node id: {target_full_node_id}")
- values = prompt[str(node_id)]
if "inputs" in values:
if return_all:
# Format items based on type
formatted_items = []
for k, v in values["inputs"].items():
if isinstance(v, float):
@@ -906,7 +1084,7 @@ The 'any_input' is required for making sure the node you want the value from exi
v = str(v)
return (v, )
else:
raise NameError(f"Widget not found: {node_id}.{widget_name}")
raise NameError(f"Widget not found: {target_full_node_id}.{widget_name}")
return (', '.join(results).strip(', '), )
class DummyOut:
@@ -2623,3 +2801,80 @@ class LazySwitchKJ:
def switch(self, switch, on_false = None, on_true=None):
value = on_true if switch else on_false
return (value,)
from comfy.patcher_extension import WrappersMP
from comfy.sampler_helpers import prepare_mask
class TTM_SampleWrapper:
def __init__(self, mask, steps):
self.mask = mask
self.steps = steps
def __call__(self, sampler, guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar):
model_options = extra_args["model_options"]
wrappers = model_options["transformer_options"]["wrappers"]
w = wrappers.setdefault(WrappersMP.APPLY_MODEL, {})
if self.mask is not None:
motion_mask = self.mask.reshape((-1, 1, self.mask.shape[-2], self.mask.shape[-1]))
motion_mask = prepare_mask(motion_mask, noise.shape, noise.device)
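# prepare_mask resizes/repeats the motion mask to the noise tensor's shape
# and moves it onto the sampling device.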
scale_latent_inpaint = guider.model_patcher.model.scale_latent_inpaint
w["TTM_ApplyModel_Wrapper"] = [TTM_ApplyModel_Wrapper(latent_image, noise, motion_mask, self.steps, scale_latent_inpaint)]
out = sampler(guider, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar)
return out
class TTM_ApplyModel_Wrapper:
def __init__(self, reference_samples, noise, motion_mask, steps, scale_latent_inpaint):
self.reference_samples = reference_samples
self.noise = noise
self.motion_mask = motion_mask
self.steps = steps
self.scale_latent_inpaint = scale_latent_inpaint
def __call__(self, executor, x, t, c_concat, c_crossattn, control, transformer_options, **kwargs):
sigmas = transformer_options["sample_sigmas"]
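# Locate the current step: prefer an exact sigma match; otherwise find the
# interval whose endpoints bracket t (the product of differences flips sign).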
matched = (sigmas == t).nonzero(as_tuple=True)[0]
if matched.numel() > 0:
current_step_index = matched.item()
else:
crossing = ((sigmas[:-1] - t) * (sigmas[1:] - t) <= 0).nonzero(as_tuple=True)[0]
current_step_index = crossing.item() if crossing.numel() > 0 else 0
next_sigma = sigmas[current_step_index + 1] if current_step_index < len(sigmas) - 1 else sigmas[current_step_index]
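# After the first step and while still within `steps`, re-noise the reference
# latents to the next sigma and blend them in under the motion mask (or
# replace x entirely when no mask is given), mimicking TTM's latent inpainting.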
if current_step_index != 0 and current_step_index < self.steps:
noisy_latent = self.scale_latent_inpaint(x=x, sigma=torch.tensor([next_sigma]), noise=self.noise.to(x), latent_image=self.reference_samples.to(x))
if self.motion_mask is not None:
x = x * (1-self.motion_mask).to(x) + noisy_latent * self.motion_mask.to(x)
else:
x = noisy_latent
return executor(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
class LatentInpaintTTM:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ("MODEL", ),
"steps": ("INT", {"default": 7, "min": 0, "max": 888, "step": 1, "tooltip": "Number of steps to apply TTM inpainting for."}),
},
"optional": {
"mask": ("MASK", {"tooltip": "Latent mask where white (1.0) is the area to inpaint and black (0.0) is the area to keep unchanged."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
EXPERIMENTAL = True
DESCRIPTION = "https://github.com/time-to-move/TTM"
CATEGORY = "KJNodes/experimental"
def patch(self, model, steps, mask=None):
m = model.clone()
m.add_wrapper_with_key(WrappersMP.SAMPLER_SAMPLE, "TTM_SampleWrapper", TTM_SampleWrapper(mask, steps))
return (m, )