From fbd6ba21825d81d160152abd5d0c4b70cd392973 Mon Sep 17 00:00:00 2001 From: Christopher Anderson Date: Mon, 15 Sep 2025 23:18:08 +1000 Subject: [PATCH] Added subgraph support for Widget2Str --- nodes/nodes.py | 914 +++++++++++++++++++++++++++++-------------------- 1 file changed, 546 insertions(+), 368 deletions(-) diff --git a/nodes/nodes.py b/nodes/nodes.py index afc793f..1c195ae 100644 --- a/nodes/nodes.py +++ b/nodes/nodes.py @@ -29,7 +29,7 @@ class BOOLConstant: def get_value(self, value): return (value,) - + class INTConstant: @classmethod def INPUT_TYPES(s): @@ -100,9 +100,9 @@ class StringConstantMultiline: return (new_string, ) - + class ScaleBatchPromptSchedule: - + RETURN_TYPES = ("STRING",) FUNCTION = "scaleschedule" CATEGORY = "KJNodes/misc" @@ -115,38 +115,38 @@ to a different frame count. def INPUT_TYPES(s): return { "required": { - "input_str": ("STRING", {"forceInput": True,"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n"}), - "old_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}), - "new_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}), - - }, - } - + "input_str": ("STRING", {"forceInput": True,"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n"}), + "old_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}), + "new_frame_count": ("INT", {"forceInput": True,"default": 1,"min": 1, "max": 4096, "step": 1}), + + }, + } + def scaleschedule(self, old_frame_count, input_str, new_frame_count): pattern = r'"(\d+)"\s*:\s*"(.*?)"(?:,|\Z)' frame_strings = dict(re.findall(pattern, input_str)) - + # Calculate the scaling factor scaling_factor = (new_frame_count - 1) / (old_frame_count - 1) - + # Initialize a dictionary to store the new frame numbers and strings new_frame_strings = {} - + # Iterate over the frame numbers and strings for old_frame, string in frame_strings.items(): # Calculate the new frame number new_frame = int(round(int(old_frame) * scaling_factor)) - + # Store the new frame number and corresponding string new_frame_strings[new_frame] = string - + # Format the output string output_str = ', '.join([f'"{k}":"{v}"' for k, v in sorted(new_frame_strings.items())]) return (output_str,) class GetLatentsFromBatchIndexed: - + RETURN_TYPES = ("LATENT",) FUNCTION = "indexedlatentsfrombatch" CATEGORY = "KJNodes/latents" @@ -158,23 +158,23 @@ Selects and returns the latents at the specified indices as an latent batch. def INPUT_TYPES(s): return { "required": { - "latents": ("LATENT",), - "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}), - "latent_format": (["BCHW", "BTCHW", "BCTHW"], {"default": "BCHW"}), - }, - } - + "latents": ("LATENT",), + "indexes": ("STRING", {"default": "0, 1, 2", "multiline": True}), + "latent_format": (["BCHW", "BTCHW", "BCTHW"], {"default": "BCHW"}), + }, + } + def indexedlatentsfrombatch(self, latents, indexes, latent_format): - + samples = latents.copy() - latent_samples = samples["samples"] + latent_samples = samples["samples"] # Parse the indexes string into a list of integers index_list = [int(index.strip()) for index in indexes.split(',')] - + # Convert list of indices to a PyTorch tensor indices_tensor = torch.tensor(index_list, dtype=torch.long) - + # Select the latents at the specified indices if latent_format == "BCHW": chosen_latents = latent_samples[indices_tensor] @@ -185,7 +185,7 @@ Selects and returns the latents at the specified indices as an latent batch. 
samples["samples"] = chosen_latents return (samples,) - + class ConditioningMultiCombine: @classmethod @@ -197,7 +197,7 @@ class ConditioningMultiCombine: "conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", ), }, - } + } RETURN_TYPES = ("CONDITIONING", "INT") RETURN_NAMES = ("combined", "inputcount") @@ -239,10 +239,10 @@ class AppendStringsToList: string1 = [string1] if not isinstance(string2, list): string2 = [string2] - + joined_string = string1 + string2 return (joined_string, ) - + class JoinStrings: @classmethod def INPUT_TYPES(cls): @@ -262,7 +262,7 @@ class JoinStrings: def joinstring(self, delimiter, string1="", string2=""): joined_string = string1 + delimiter + string2 return (joined_string, ) - + class JoinStringMulti: @classmethod def INPUT_TYPES(s): @@ -276,7 +276,7 @@ class JoinStringMulti: "optional": { "string_2": ("STRING", {"default": '', "forceInput": True}), } - } + } RETURN_TYPES = ("STRING",) RETURN_NAMES = ("string",) @@ -315,8 +315,8 @@ class CondPassThrough: "optional": { "positive": ("CONDITIONING", ), "negative": ("CONDITIONING", ), - }, - } + }, + } RETURN_TYPES = ("CONDITIONING", "CONDITIONING",) RETURN_NAMES = ("positive", "negative") @@ -334,12 +334,12 @@ class ModelPassThrough: @classmethod def INPUT_TYPES(s): return { - "required": { + "required": { }, "optional": { "model": ("MODEL", ), - }, - } + }, + } RETURN_TYPES = ("MODEL", ) RETURN_NAMES = ("model",) @@ -351,15 +351,15 @@ class ModelPassThrough: """ def passthrough(self, model=None): - return (model,) + return (model,) def append_helper(t, mask, c, set_area_to_bounds, strength): - n = [t[0], t[1].copy()] - _, h, w = mask.shape - n[1]['mask'] = mask - n[1]['set_area_to_bounds'] = set_area_to_bounds - n[1]['mask_strength'] = strength - c.append(n) + n = [t[0], t[1].copy()] + _, h, w = mask.shape + n[1]['mask'] = mask + n[1]['set_area_to_bounds'] = set_area_to_bounds + n[1]['mask_strength'] = strength + c.append(n) class ConditioningSetMaskAndCombine: @classmethod @@ -600,26 +600,26 @@ Bundles multiple conditioning mask and combine nodes into one,functionality is i for t in negative_5: append_helper(t, mask_5, c2, set_area_to_bounds, mask_5_strength) return (c, c2) - + class VRAM_Debug: - + @classmethod - + def INPUT_TYPES(s): - return { - "required": { - - "empty_cache": ("BOOLEAN", {"default": True}), - "gc_collect": ("BOOLEAN", {"default": True}), - "unload_all_models": ("BOOLEAN", {"default": False}), - }, - "optional": { - "any_input": (IO.ANY,), - "image_pass": ("IMAGE",), - "model_pass": ("MODEL",), + return { + "required": { + + "empty_cache": ("BOOLEAN", {"default": True}), + "gc_collect": ("BOOLEAN", {"default": True}), + "unload_all_models": ("BOOLEAN", {"default": False}), + }, + "optional": { + "any_input": (IO.ANY,), + "image_pass": ("IMAGE",), + "model_pass": ("MODEL",), + } } - } - + RETURN_TYPES = (IO.ANY, "IMAGE","MODEL","INT", "INT",) RETURN_NAMES = ("any_output", "image_pass", "model_pass", "freemem_before", "freemem_after") FUNCTION = "VRAMdebug" @@ -644,23 +644,23 @@ reports free VRAM before and after the operations. 
print("VRAMdebug: free memory after: ", f"{freemem_after:,.0f}") print("VRAMdebug: freed memory: ", f"{freemem_after - freemem_before:,.0f}") return {"ui": { - "text": [f"{freemem_before:,.0f}x{freemem_after:,.0f}"]}, - "result": (any_input, image_pass, model_pass, freemem_before, freemem_after) + "text": [f"{freemem_before:,.0f}x{freemem_after:,.0f}"]}, + "result": (any_input, image_pass, model_pass, freemem_before, freemem_after) } class SomethingToString: @classmethod - + def INPUT_TYPES(s): - return { - "required": { - "input": (IO.ANY, ), - }, - "optional": { - "prefix": ("STRING", {"default": ""}), - "suffix": ("STRING", {"default": ""}), - } - } + return { + "required": { + "input": (IO.ANY, ), + }, + "optional": { + "prefix": ("STRING", {"default": ""}), + "suffix": ("STRING", {"default": ""}), + } + } RETURN_TYPES = ("STRING",) FUNCTION = "stringify" CATEGORY = "KJNodes/text" @@ -703,36 +703,36 @@ Delays the execution for the input amount of time. total_seconds = minutes * 60 + seconds time.sleep(total_seconds) return input, - + class EmptyLatentImagePresets: @classmethod - def INPUT_TYPES(cls): + def INPUT_TYPES(cls): return { - "required": { - "dimensions": ( - [ - '512 x 512 (1:1)', - '768 x 512 (1.5:1)', - '960 x 512 (1.875:1)', - '1024 x 512 (2:1)', - '1024 x 576 (1.778:1)', - '1536 x 640 (2.4:1)', - '1344 x 768 (1.75:1)', - '1216 x 832 (1.46:1)', - '1152 x 896 (1.286:1)', - '1024 x 1024 (1:1)', - ], - { - "default": '512 x 512 (1:1)' - }), - - "invert": ("BOOLEAN", {"default": False}), - "batch_size": ("INT", { - "default": 1, - "min": 1, - "max": 4096 - }), - }, + "required": { + "dimensions": ( + [ + '512 x 512 (1:1)', + '768 x 512 (1.5:1)', + '960 x 512 (1.875:1)', + '1024 x 512 (2:1)', + '1024 x 576 (1.778:1)', + '1536 x 640 (2.4:1)', + '1344 x 768 (1.75:1)', + '1216 x 832 (1.46:1)', + '1152 x 896 (1.286:1)', + '1024 x 1024 (1:1)', + ], + { + "default": '512 x 512 (1:1)' + }), + + "invert": ("BOOLEAN", {"default": False}), + "batch_size": ("INT", { + "default": 1, + "min": 1, + "max": 4096 + }), + }, } RETURN_TYPES = ("LATENT", "INT", "INT") @@ -747,7 +747,7 @@ class EmptyLatentImagePresets: # Remove the aspect ratio part result[0] = result[0].split('(')[0].strip() result[1] = result[1].split('(')[0].strip() - + if invert: width = int(result[1].split(' ')[0]) height = int(result[0]) @@ -767,18 +767,18 @@ class EmptyLatentImageCustomPresets: except FileNotFoundError: dimensions_dict = [] return { - "required": { - "dimensions": ( - [f"{d['label']} - {d['value']}" for d in dimensions_dict], - ), - - "invert": ("BOOLEAN", {"default": False}), - "batch_size": ("INT", { - "default": 1, - "min": 1, - "max": 4096 - }), - }, + "required": { + "dimensions": ( + [f"{d['label']} - {d['value']}" for d in dimensions_dict], + ), + + "invert": ("BOOLEAN", {"default": False}), + "batch_size": ("INT", { + "default": 1, + "min": 1, + "max": 4096 + }), + }, } RETURN_TYPES = ("LATENT", "INT", "INT") @@ -791,19 +791,20 @@ The choices are loaded from 'custom_dimensions.json' in the nodes folder. 
""" def generate(self, dimensions, invert, batch_size): - from nodes import EmptyLatentImage - # Split the string into label and value - label, value = dimensions.split(' - ') - # Split the value into width and height - width, height = [x.strip() for x in value.split('x')] - - if invert: - width, height = height, width - - latent = EmptyLatentImage().generate(int(width), int(height), batch_size)[0] - - return (latent, int(width), int(height),) + from nodes import EmptyLatentImage + # Split the string into label and value + label, value = dimensions.split(' - ') + # Split the value into width and height + width, height = [x.strip() for x in value.split('x')] + if invert: + width, height = height, width + + latent = EmptyLatentImage().generate(int(width), int(height), batch_size)[0] + + return (latent, int(width), int(height),) + +# noinspection PyShadowingNames class WidgetToString: @classmethod def IS_CHANGED(cls,*,id,node_title,any_input,**kwargs): @@ -819,11 +820,11 @@ class WidgetToString: "return_all": ("BOOLEAN", {"default": False}), }, "optional": { - "any_input": (IO.ANY, ), - "node_title": ("STRING", {"multiline": False}), - "allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}), - - }, + "any_input": (IO.ANY, ), + "node_title": ("STRING", {"multiline": False}), + "allowed_float_decimals": ("INT", {"default": 2, "min": 0, "max": 10, "tooltip": "Number of decimal places to display for float values"}), + + }, "hidden": {"extra_pnginfo": "EXTRA_PNGINFO", "prompt": "PROMPT", "unique_id": "UNIQUE_ID",}, @@ -842,54 +843,231 @@ The 'any_input' is required for making sure the node you want the value from exi """ def get_widget_value(self, id, widget_name, extra_pnginfo, prompt, unique_id, return_all=False, any_input=None, node_title="", allowed_float_decimals=2): + """ + Retrieves the value of the specified widget from a node in the workflow and + returns it as a string. + + If no `id` or `node_title` is provided, the method attempts to identify the + node using the `any_input` connection in the workflow. Enable node ID display + in ComfyUI's "Manager" menu to view node IDs, or use a manually edited node + title for searching. NOTE: A node does not have a title unless it is manually + edited to something other than its default value. + + Args: + id (int): The unique ID of the target node. If 0, the method relies on + other methods to determine the node. TODO: change to a STRING (breaking change) + widget_name (str): The name of the widget whose value needs to be retrieved. + extra_pnginfo (dict): A dictionary containing workflow metadata, including + node connections and state. + prompt (dict): A dictionary containing node-specific data with input + settings to extract widget values. + unique_id (str): The unique identifier of the current node instance, used + to match the `any_input` connection. + return_all (bool): If True, retrieves and returns all input values from + the node, formatted as a string. + any_input (str): Optional. A link reference used to determine the node if + no `id` or `node_title` is provided. + node_title (str): Optional. The title of the node to search for. Titles + are valid only if manually assigned in ComfyUI. + allowed_float_decimals (int): The number of decimal places to which float + values should be rounded in the output. + + Returns: + str or tuple: + - If `return_all` is False, returns a tuple with the value of the + specified widget. 
+ - If `return_all` is True, returns a formatted string containing all + input values for the node. + + Raises: + ValueError: If no matching node is found for the given `id`, `node_title`, + or `any_input`. + NameError: If the specified widget does not exist in the identified node. + """ workflow = extra_pnginfo["workflow"] - #print(json.dumps(workflow, indent=4)) results = [] - node_id = None # Initialize node_id to handle cases where no match is found - link_id = None + target_full_node_id = None # string like "5", "5:1", "5:9:6" + active_link_id = None + + # Normalize incoming ids which may be lists/tuples (e.g., ["7:9:14", 0]) + def normalize_any_id(value): + # If list/tuple, take the first element which should be the id/path + if isinstance(value, (list, tuple)) and value: + value = value[0] + # Convert ints to str + if isinstance(value, int): + return str(value) + # Pass through strings; None -> empty + return value if isinstance(value, str) else "" + + id_str = normalize_any_id(id) + unique_id_str = normalize_any_id(unique_id) + + # Map of (scope_key, link_id) -> full_node_id + # scope_key: '' for top-level, or the subgraph instance path for nested nodes (e.g., '5', '5:9') link_to_node_map = {} - for node in workflow["nodes"]: - if node_title: - if "title" in node: - if node["title"] == node_title: - node_id = node["id"] - break - else: - print("Node title not found.") - elif id != 0: - if node["id"] == id: - node_id = id - break - elif any_input is not None: - if node["type"] == "WidgetToString" and node["id"] == int(unique_id) and not link_id: - for node_input in node["inputs"]: - if node_input["name"] == "any_input": - link_id = node_input["link"] - - # Construct a map of links to node IDs for future reference - node_outputs = node.get("outputs", None) - if not node_outputs: + # Build a map of subgraph id -> definition for quick lookup + defs = workflow.get("definitions", {}) or {} + subgraph_defs = {sg.get("id"): sg for sg in (defs.get("subgraphs", []) or []) if sg.get("id")} + + # Helper: register output links -> node map (scoped) + def register_links(scope_key, node_obj, full_node_id): + outputs = node_obj.get("outputs") or [] + for out in outputs: + links = out.get("links") + if not links: continue - for output in node_outputs: - node_links = output.get("links", None) - if not node_links: - continue - for link in node_links: - link_to_node_map[link] = node["id"] - if link_id and link == link_id: - break - - if link_id: - node_id = link_to_node_map.get(link_id, None) + if isinstance(links, list): + for lid in links: + if lid is None: + continue + link_to_node_map[(scope_key, lid)] = full_node_id - if node_id is None: - raise ValueError("No matching node found for the given title or id") + # Recursive emitter for a subgraph instance + # instance_path: the full path to this subgraph instance (e.g., '5' or '5:9') + def emit_subgraph_instance(sub_def, instance_path): + for snode in (sub_def.get("nodes") or []): + child_id = str(snode.get("id")) + full_id = f"{instance_path}:{child_id}" + # Yield the node with the scope of this subgraph instance + yield full_id, instance_path, snode + # If this node itself is a subgraph instance, recurse + stype = snode.get("type") + nested_def = subgraph_defs.get(stype) + if nested_def is not None: + nested_instance_path = full_id # e.g., '5:9' + for inner in emit_subgraph_instance(nested_def, nested_instance_path): + yield inner + + # Master iterator: yields all nodes with their full_node_id and scope + def iter_all_nodes(): + # 1) Top-level 
nodes + for node in workflow.get("nodes", []): + full_node_id = str(node.get("id")) + scope_key = "" # top-level link id space + yield full_node_id, scope_key, node + + # 2) If a top-level node is an instance of a subgraph, emit its internal nodes + ntype = node.get("type") + sg_def = subgraph_defs.get(ntype) + if sg_def is not None: + instance_path = full_node_id # e.g., '5' + for item in emit_subgraph_instance(sg_def, instance_path): + yield item + + # Helpers for id/unique_id handling + def match_id_with_fullness(candidate_full_id, requested_id): + # Exact match if the request is fully qualified + if ":" in requested_id: + return candidate_full_id == requested_id + # Otherwise, allow exact top-level id or ending with ":child" + return candidate_full_id == requested_id or candidate_full_id.endswith(f":{requested_id}") + + def parent_scope_of(full_id): + parts = full_id.split(":") + return ":".join(parts[:-1]) if len(parts) > 1 else "" + + def resolve_scope_from_unique_id(u_str): + # Fully qualified: everything before the last segment is the scope + if ":" in u_str: + return parent_scope_of(u_str) + + # Not qualified: try to infer from prompt keys by suffix + suffix = f":{u_str}" + matches = [k for k in prompt.keys() if isinstance(k, str) and k.endswith(suffix)] + matches = list(dict.fromkeys(matches)) # dedupe + if len(matches) == 1: + return parent_scope_of(matches[0]) + elif len(matches) == 0: + return None + else: + raise ValueError( + f"Ambiguous unique_id '{u_str}'. Multiple subgraph instances match. " + f"Use a fully qualified id like 'parentPath:{u_str}' (e.g., '5:9:{u_str}')." + ) + + # First: build a complete list of nodes and the scoped link map + all_nodes = [] + for full_node_id, scope_key, node in iter_all_nodes(): + all_nodes.append((full_node_id, scope_key, node)) + register_links(scope_key, node, full_node_id) + + # Try title or id first + if node_title: + for full_node_id, _, node in all_nodes: + if "title" in node and node.get("title") == node_title: + target_full_node_id = full_node_id + break + # If title matched, do not attempt any_input fallback + any_input = None + elif id_str not in ("", "0"): + matches = [fid for fid, _, _ in all_nodes if match_id_with_fullness(fid, id_str)] + if len(matches) > 1 and ":" not in id_str and any(m != id_str for m in matches): + raise ValueError( + f"Ambiguous id '{id_str}'. Multiple nodes match across (nested) subgraphs. " + f"Use a fully qualified id like '5:9:{id_str}'." 
+ ) + target_full_node_id = matches[0] if matches else None + + # Resolve via any_input + unique_id if still not found + if target_full_node_id is None and any_input is not None and unique_id_str: + # If unique_id is fully qualified, select that exact node + wts_full_id = None + if ":" in unique_id_str: + for fid, _, node in all_nodes: + if fid == unique_id_str and node.get("type") == "WidgetToString": + wts_full_id = fid + break + if wts_full_id is None: + raise ValueError(f"No WidgetToString found for unique_id '{unique_id_str}'") + found_scope_key = parent_scope_of(wts_full_id) + else: + # Infer scope from prompt keys when unqualified + found_scope_key = resolve_scope_from_unique_id(unique_id_str) + candidates = [] + if found_scope_key: + candidates.append(f"{found_scope_key}:{unique_id_str}") + else: + candidates.append(unique_id_str) + + for fid, scope_key, node in all_nodes: + if node.get("type") == "WidgetToString" and fid in candidates: + wts_full_id = fid + if not found_scope_key: + found_scope_key = parent_scope_of(fid) + break + + if wts_full_id is None: + raise ValueError(f"No WidgetToString found for unique_id '{unique_id_str}'") + + # With the WidgetToString located, read its any_input link id + wts_node = next(node for fid, _, node in all_nodes if fid == wts_full_id) + for node_input in (wts_node.get("inputs") or []): + if node_input.get("name") == "any_input": + active_link_id = node_input.get("link") + break + + if active_link_id is None: + raise ValueError(f"WidgetToString '{wts_full_id}' has no 'any_input' link") + + # Resolve the producer of that link within the correct scope + target_full_node_id = link_to_node_map.get((found_scope_key or "", active_link_id)) + if target_full_node_id is None: + raise ValueError( + f"Could not resolve link {active_link_id} in scope '{found_scope_key}'. " + f"The subgraph clone’s links may not have been discovered." + ) + + if target_full_node_id is None: + raise ValueError("No matching node found for the given title, id, or any_input") + + values = prompt.get(str(target_full_node_id)) + if not values: + raise ValueError(f"No prompt entry found for node id: {target_full_node_id}") - values = prompt[str(node_id)] if "inputs" in values: if return_all: - # Format items based on type formatted_items = [] for k, v in values["inputs"].items(): if isinstance(v, float): @@ -906,7 +1084,7 @@ The 'any_input' is required for making sure the node you want the value from exi v = str(v) return (v, ) else: - raise NameError(f"Widget not found: {node_id}.{widget_name}") + raise NameError(f"Widget not found: {target_full_node_id}.{widget_name}") return (', '.join(results).strip(', '), ) class DummyOut: @@ -915,7 +1093,7 @@ class DummyOut: def INPUT_TYPES(cls): return { "required": { - "any_input": (IO.ANY, ), + "any_input": (IO.ANY, ), } } @@ -930,7 +1108,7 @@ A way to get previews in the UI without saving anything to disk. 
def dummy(self, any_input): return (any_input,) - + class FlipSigmasAdjusted: @classmethod def INPUT_TYPES(s): @@ -947,7 +1125,7 @@ class FlipSigmasAdjusted: FUNCTION = "get_sigmas_adjusted" def get_sigmas_adjusted(self, sigmas, divide_by_last_sigma, divide_by, offset_by): - + sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 @@ -958,9 +1136,9 @@ class FlipSigmasAdjusted: if 0 <= offset_index < len(sigmas): adjusted_sigmas[i] = sigmas[offset_index] else: - adjusted_sigmas[i] = 0.0001 + adjusted_sigmas[i] = 0.0001 if adjusted_sigmas[0] == 0: - adjusted_sigmas[0] = 0.0001 + adjusted_sigmas[0] = 0.0001 if divide_by_last_sigma: adjusted_sigmas = adjusted_sigmas / adjusted_sigmas[-1] @@ -968,16 +1146,16 @@ class FlipSigmasAdjusted: array_string = np.array2string(sigma_np_array, precision=2, separator=', ', threshold=np.inf) adjusted_sigmas = adjusted_sigmas / divide_by return (adjusted_sigmas, array_string,) - + class CustomSigmas: @classmethod def INPUT_TYPES(s): return {"required": - { - "sigmas_string" :("STRING", {"default": "14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029","multiline": True}), - "interpolate_to_steps": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}), - } - } + { + "sigmas_string" :("STRING", {"default": "14.615, 6.475, 3.861, 2.697, 1.886, 1.396, 0.963, 0.652, 0.399, 0.152, 0.029","multiline": True}), + "interpolate_to_steps": ("INT", {"default": 10,"min": 0, "max": 255, "step": 1}), + } + } RETURN_TYPES = ("SIGMAS",) RETURN_NAMES = ("SIGMAS",) CATEGORY = "KJNodes/noise" @@ -1001,7 +1179,7 @@ SVD: sigmas_tensor = self.loglinear_interp(sigmas_tensor, interpolate_to_steps + 1) sigmas_tensor[-1] = 0 return (sigmas_tensor.float(),) - + def loglinear_interp(self, t_steps, num_steps): """ Performs log-linear interpolation of a given array of decreasing numbers. 
@@ -1010,22 +1188,22 @@ SVD: xs = np.linspace(0, 1, len(t_steps_np)) ys = np.log(t_steps_np[::-1]) - + new_xs = np.linspace(0, 1, num_steps) new_ys = np.interp(new_xs, xs, ys) - + interped_ys = np.exp(new_ys)[::-1].copy() interped_ys_tensor = torch.tensor(interped_ys) return interped_ys_tensor - + class StringToFloatList: @classmethod def INPUT_TYPES(s): return {"required": - { - "string" :("STRING", {"default": "1, 2, 3", "multiline": True}), - } - } + { + "string" :("STRING", {"default": "1, 2, 3", "multiline": True}), + } + } RETURN_TYPES = ("FLOAT",) RETURN_NAMES = ("FLOAT",) CATEGORY = "KJNodes/misc" @@ -1035,28 +1213,28 @@ class StringToFloatList: float_list = [float(x.strip()) for x in string.split(',')] return (float_list,) - + class InjectNoiseToLatent: @classmethod def INPUT_TYPES(s): return {"required": { - "latents":("LATENT",), + "latents":("LATENT",), "strength": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 200.0, "step": 0.0001}), "noise": ("LATENT",), "normalize": ("BOOLEAN", {"default": False}), "average": ("BOOLEAN", {"default": False}), - }, + }, "optional":{ "mask": ("MASK", ), "mix_randn_amount": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.001}), "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}), } - } - + } + RETURN_TYPES = ("LATENT",) FUNCTION = "injectnoise" CATEGORY = "KJNodes/noise" - + def injectnoise(self, latents, strength, noise, normalize, average, mix_randn_amount=0, seed=None, mask=None): samples = latents["samples"].clone().cpu() noise = noise["samples"].clone().cpu() @@ -1079,22 +1257,22 @@ class InjectNoiseToLatent: generator = torch.manual_seed(seed) rand_noise = torch.randn(noised.size(), dtype=noised.dtype, layout=noised.layout, generator=generator, device="cpu") noised = noised + (mix_randn_amount * rand_noise) - + return ({"samples":noised},) - + class SoundReactive: @classmethod def INPUT_TYPES(s): - return {"required": { + return {"required": { "sound_level": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 99999, "step": 0.01}), "start_range_hz": ("INT", {"default": 150, "min": 0, "max": 9999, "step": 1}), "end_range_hz": ("INT", {"default": 2000, "min": 0, "max": 9999, "step": 1}), "multiplier": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 99999, "step": 0.01}), "smoothing_factor": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), "normalize": ("BOOLEAN", {"default": False}), - }, - } - + }, + } + RETURN_TYPES = ("FLOAT","INT",) RETURN_NAMES =("sound_level", "sound_level_int",) FUNCTION = "react" @@ -1104,7 +1282,7 @@ Reacts to the sound level of the input. Uses your browsers sound input options and requires. Meant to be used with realtime diffusion with autoqueue. """ - + def react(self, sound_level, start_range_hz, end_range_hz, smoothing_factor, multiplier, normalize): sound_level *= multiplier @@ -1113,12 +1291,12 @@ Meant to be used with realtime diffusion with autoqueue. 
sound_level /= 255 sound_level_int = int(sound_level) - return (sound_level, sound_level_int, ) - + return (sound_level, sound_level_int, ) + class GenerateNoise: @classmethod def INPUT_TYPES(s): - return {"required": { + return {"required": { "width": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), "height": ("INT", {"default": 512,"min": 16, "max": 4096, "step": 1}), "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}), @@ -1126,7 +1304,7 @@ class GenerateNoise: "multiplier": ("FLOAT", {"default": 1.0,"min": 0.0, "max": 4096, "step": 0.01}), "constant_batch_noise": ("BOOLEAN", {"default": False}), "normalize": ("BOOLEAN", {"default": False}), - }, + }, "optional": { "model": ("MODEL", ), "sigmas": ("SIGMAS", ), @@ -1134,14 +1312,14 @@ class GenerateNoise: "shape": (["BCHW", "BCTHW","BTCHW",],), } } - + RETURN_TYPES = ("LATENT",) FUNCTION = "generatenoise" CATEGORY = "KJNodes/noise" DESCRIPTION = """ Generates noise for injection or to be used as empty latents on samplers with add_noise off. """ - + def generatenoise(self, batch_size, width, height, seed, multiplier, constant_batch_noise, normalize, sigmas=None, model=None, latent_channels=4, shape="BCHW"): generator = torch.manual_seed(seed) @@ -1163,7 +1341,7 @@ Generates noise for injection or to be used as empty latents on samplers with ad if constant_batch_noise: noise = noise[0].repeat(batch_size, 1, 1, 1) - + return ({"samples":noise}, ) def camera_embeddings(elevation, azimuth): @@ -1171,14 +1349,14 @@ def camera_embeddings(elevation, azimuth): azimuth = torch.as_tensor([azimuth]) embeddings = torch.stack( [ - torch.deg2rad( - (90 - elevation) - (90) - ), # Zero123 polar is 90-elevation - torch.sin(torch.deg2rad(azimuth)), - torch.cos(torch.deg2rad(azimuth)), - torch.deg2rad( - 90 - torch.full_like(elevation, 0) - ), + torch.deg2rad( + (90 - elevation) - (90) + ), # Zero123 polar is 90-elevation + torch.sin(torch.deg2rad(azimuth)), + torch.cos(torch.deg2rad(azimuth)), + torch.deg2rad( + 90 - torch.full_like(elevation, 0) + ), ], dim=-1).unsqueeze(1) return embeddings @@ -1204,8 +1382,8 @@ class StableZero123_BatchSchedule: "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],), "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n7:(1.0),\n15:(0.0)\n", "multiline": True}), "elevation_points_string": ("STRING", {"default": "0:(0.0),\n7:(0.0),\n15:(0.0)\n", "multiline": True}), - }} - + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") FUNCTION = "encode" @@ -1224,14 +1402,14 @@ class StableZero123_BatchSchedule: return 1 - (1 - t) * (1 - t) def ease_in_out(t): return 3 * t * t - 2 * t * t * t - + # Parse the azimuth input string into a list of tuples azimuth_points = [] azimuth_points_string = azimuth_points_string.rstrip(',\n') for point_str in azimuth_points_string.split(','): frame_str, azimuth_str = point_str.split(':') frame = int(frame_str.strip()) - azimuth = float(azimuth_str.strip()[1:-1]) + azimuth = float(azimuth_str.strip()[1:-1]) azimuth_points.append((frame, azimuth)) # Sort the points by frame number azimuth_points.sort(key=lambda x: x[0]) @@ -1242,7 +1420,7 @@ class StableZero123_BatchSchedule: for point_str in elevation_points_string.split(','): frame_str, elevation_str = point_str.split(':') frame = int(frame_str.strip()) - elevation_val = float(elevation_str.strip()[1:-1]) + elevation_val = float(elevation_str.strip()[1:-1]) elevation_points.append((frame, elevation_val)) # Sort the points by frame number 
elevation_points.sort(key=lambda x: x[0]) @@ -1255,7 +1433,7 @@ class StableZero123_BatchSchedule: positive_pooled_out = [] negative_cond_out = [] negative_pooled_out = [] - + #azimuth interpolation for i in range(batch_size): # Find the interpolated azimuth for the current frame @@ -1275,7 +1453,7 @@ class StableZero123_BatchSchedule: fraction = ease_out(fraction) elif interpolation == "ease_in_out": fraction = ease_in_out(fraction) - + # Use the new interpolate_angle function interpolated_azimuth = interpolate_angle(azimuth_points[prev_point][1], azimuth_points[next_point][1], fraction) else: @@ -1296,7 +1474,7 @@ class StableZero123_BatchSchedule: fraction = ease_out(fraction) elif interpolation == "ease_in_out": fraction = ease_in_out(fraction) - + interpolated_elevation = interpolate_angle(elevation_points[prev_elevation_point][1], elevation_points[next_elevation_point][1], fraction) else: interpolated_elevation = elevation_points[prev_elevation_point][1] @@ -1337,8 +1515,8 @@ class SV3D_BatchSchedule: "interpolation": (["linear", "ease_in", "ease_out", "ease_in_out"],), "azimuth_points_string": ("STRING", {"default": "0:(0.0),\n9:(180.0),\n20:(360.0)\n", "multiline": True}), "elevation_points_string": ("STRING", {"default": "0:(0.0),\n9:(0.0),\n20:(0.0)\n", "multiline": True}), - }} - + }} + RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT") RETURN_NAMES = ("positive", "negative", "latent") FUNCTION = "encode" @@ -1362,14 +1540,14 @@ https://huggingface.co/stabilityai/sv3d return 1 - (1 - t) * (1 - t) def ease_in_out(t): return 3 * t * t - 2 * t * t * t - + # Parse the azimuth input string into a list of tuples azimuth_points = [] azimuth_points_string = azimuth_points_string.rstrip(',\n') for point_str in azimuth_points_string.split(','): frame_str, azimuth_str = point_str.split(':') frame = int(frame_str.strip()) - azimuth = float(azimuth_str.strip()[1:-1]) + azimuth = float(azimuth_str.strip()[1:-1]) azimuth_points.append((frame, azimuth)) # Sort the points by frame number azimuth_points.sort(key=lambda x: x[0]) @@ -1380,7 +1558,7 @@ https://huggingface.co/stabilityai/sv3d for point_str in elevation_points_string.split(','): frame_str, elevation_str = point_str.split(':') frame = int(frame_str.strip()) - elevation_val = float(elevation_str.strip()[1:-1]) + elevation_val = float(elevation_str.strip()[1:-1]) elevation_points.append((frame, elevation_val)) # Sort the points by frame number elevation_points.sort(key=lambda x: x[0]) @@ -1408,7 +1586,7 @@ https://huggingface.co/stabilityai/sv3d fraction = ease_out(fraction) elif interpolation == "ease_in_out": fraction = ease_in_out(fraction) - + interpolated_azimuth = linear_interpolate(azimuth_points[prev_point][1], azimuth_points[next_point][1], fraction) else: interpolated_azimuth = azimuth_points[prev_point][1] @@ -1430,7 +1608,7 @@ https://huggingface.co/stabilityai/sv3d fraction = ease_out(fraction) elif interpolation == "ease_in_out": fraction = ease_in_out(fraction) - + interpolated_elevation = linear_interpolate(elevation_points[prev_elevation_point][1], elevation_points[next_elevation_point][1], fraction) else: interpolated_elevation = elevation_points[prev_elevation_point][1] @@ -1455,7 +1633,7 @@ class LoadResAdapterNormalization: "required": { "model": ("MODEL",), "resadapter_path": (folder_paths.get_filename_list("checkpoints"), ) - } + } } RETURN_TYPES = ("MODEL",) @@ -1488,7 +1666,7 @@ class LoadResAdapterNormalization: raise Exception("Could not patch model, this way of patching was added to ComfyUI on March 3rd 
2024, is your ComfyUI up to date?") print("ResAdapter: Added resnet normalization patches") return (model_clone, ) - + class Superprompt: @classmethod def INPUT_TYPES(s): @@ -1497,7 +1675,7 @@ class Superprompt: "instruction_prompt": ("STRING", {"default": 'Expand the following prompt to add more detail', "multiline": True}), "prompt": ("STRING", {"default": '', "multiline": True, "forceInput": True}), "max_new_tokens": ("INT", {"default": 128, "min": 1, "max": 4096, "step": 1}), - } + } } RETURN_TYPES = ("STRING",) @@ -1518,28 +1696,28 @@ https://huggingface.co/roborovski/superprompt-v1 checkpoint_path = os.path.join(script_directory, "models","superprompt-v1") if not os.path.exists(checkpoint_path): - print(f"Downloading model to: {checkpoint_path}") - from huggingface_hub import snapshot_download - snapshot_download(repo_id="roborovski/superprompt-v1", - local_dir=checkpoint_path, - local_dir_use_symlinks=False) + print(f"Downloading model to: {checkpoint_path}") + from huggingface_hub import snapshot_download + snapshot_download(repo_id="roborovski/superprompt-v1", + local_dir=checkpoint_path, + local_dir_use_symlinks=False) tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small", legacy=False) model = T5ForConditionalGeneration.from_pretrained(checkpoint_path, device_map=device) model.to(device) input_text = instruction_prompt + ": " + prompt - + input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device) outputs = model.generate(input_ids, max_new_tokens=max_new_tokens) out = (tokenizer.decode(outputs[0])) out = out.replace('', '') out = out.replace('', '') - + return (out, ) class CameraPoseVisualizer: - + @classmethod def INPUT_TYPES(s): return {"required": { @@ -1550,12 +1728,12 @@ class CameraPoseVisualizer: "use_exact_fx": ("BOOLEAN", {"default": False}), "relative_c2w": ("BOOLEAN", {"default": True}), "use_viewer": ("BOOLEAN", {"default": False}), - }, + }, "optional": { "cameractrl_poses": ("CAMERACTRL_POSES", {"default": None}), } - } - + } + RETURN_TYPES = ("IMAGE",) FUNCTION = "plot" CATEGORY = "KJNodes/misc" @@ -1563,7 +1741,7 @@ class CameraPoseVisualizer: Visualizes the camera poses, from Animatediff-Evolved CameraCtrl Pose or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot. """ - + def plot(self, pose_file_path, scale, base_xval, zval, use_exact_fx, relative_c2w, use_viewer, cameractrl_poses=None): import matplotlib as mpl import matplotlib.pyplot as plt @@ -1616,7 +1794,7 @@ or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot. for frame_idx, c2w in enumerate(c2ws): self.extrinsic2pyramid(c2w, frame_idx / total_frames, hw_ratio=1/1, base_xval=base_xval, - zval=(fxs[frame_idx] if use_exact_fx else zval)) + zval=(fxs[frame_idx] if use_exact_fx else zval)) # Create the colorbar cmap = mpl.cm.rainbow @@ -1633,7 +1811,7 @@ or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot. # Assuming you want to set the ticks at every 10th frame ticks = np.arange(0, total_frames, 10) colorbar.ax.yaxis.set_ticks(ticks) - + plt.title('') plt.draw() buf = io.BytesIO() @@ -1652,16 +1830,16 @@ or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot. 
import matplotlib.pyplot as plt from mpl_toolkits.mplot3d.art3d import Poly3DCollection vertex_std = np.array([[0, 0, 0, 1], - [base_xval, -base_xval * hw_ratio, zval, 1], - [base_xval, base_xval * hw_ratio, zval, 1], - [-base_xval, base_xval * hw_ratio, zval, 1], - [-base_xval, -base_xval * hw_ratio, zval, 1]]) + [base_xval, -base_xval * hw_ratio, zval, 1], + [base_xval, base_xval * hw_ratio, zval, 1], + [-base_xval, base_xval * hw_ratio, zval, 1], + [-base_xval, -base_xval * hw_ratio, zval, 1]]) vertex_transformed = vertex_std @ extrinsic.T meshes = [[vertex_transformed[0, :-1], vertex_transformed[1][:-1], vertex_transformed[2, :-1]], - [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], - [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], - [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]], - [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]] + [vertex_transformed[0, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1]], + [vertex_transformed[0, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]], + [vertex_transformed[0, :-1], vertex_transformed[4, :-1], vertex_transformed[1, :-1]], + [vertex_transformed[1, :-1], vertex_transformed[2, :-1], vertex_transformed[3, :-1], vertex_transformed[4, :-1]]] color = color_map if isinstance(color_map, str) else plt.cm.rainbow(color_map) @@ -1692,9 +1870,9 @@ or a .txt file with RealEstate camera intrinsics and coordinates, in a 3D plot. ret_poses = [np.linalg.inv(w2c) for w2c in w2cs] ret_poses = [transform_matrix @ x for x in ret_poses] return np.array(ret_poses, dtype=np.float32) - - - + + + class CheckpointPerturbWeights: @classmethod @@ -1705,7 +1883,7 @@ class CheckpointPerturbWeights: "final_layer": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), "rest_of_the_blocks": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 10.0, "step": 0.001}), "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}), - } + } } RETURN_TYPES = ("MODEL",) FUNCTION = "mod" @@ -1729,7 +1907,7 @@ class CheckpointPerturbWeights: pbar = ProgressBar(len(keys)) for k in keys: v = dict[k] - print(f'{k}: {v.std()}') + print(f'{k}: {v.std()}') if k.startswith('joint_blocks'): multiplier = joint_blocks elif k.startswith('final_layer'): @@ -1740,16 +1918,16 @@ class CheckpointPerturbWeights: pbar.update(1) model_copy.model.diffusion_model.load_state_dict(dict) return model_copy, - + class DifferentialDiffusionAdvanced(): @classmethod def INPUT_TYPES(s): return {"required": { - "model": ("MODEL", ), - "samples": ("LATENT",), - "mask": ("MASK",), - "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}), - }} + "model": ("MODEL", ), + "samples": ("LATENT",), + "mask": ("MASK",), + "multiplier": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.001}), + }} RETURN_TYPES = ("MODEL", "LATENT") FUNCTION = "apply" CATEGORY = "_for_testing" @@ -1778,7 +1956,7 @@ class DifferentialDiffusionAdvanced(): threshold = (current_ts - ts_to) / (ts_from - ts_to) / self.multiplier return (denoise_mask >= threshold).to(denoise_mask.dtype) - + class FluxBlockLoraSelect: def __init__(self): self.loaded_lora = None @@ -1795,7 +1973,7 @@ class FluxBlockLoraSelect: arg_dict["single_blocks.{}.".format(i)] = argument return {"required": arg_dict} - + RETURN_TYPES = ("SELECTEDDITBLOCKS", ) RETURN_NAMES = ("blocks", ) OUTPUT_TOOLTIPS = 
("The modified diffusion model.",) @@ -1806,7 +1984,7 @@ class FluxBlockLoraSelect: def load_lora(self, **kwargs): return (kwargs,) - + class HunyuanVideoBlockLoraSelect: @classmethod def INPUT_TYPES(s): @@ -1820,7 +1998,7 @@ class HunyuanVideoBlockLoraSelect: arg_dict["single_blocks.{}.".format(i)] = argument return {"required": arg_dict} - + RETURN_TYPES = ("SELECTEDDITBLOCKS", ) RETURN_NAMES = ("blocks", ) OUTPUT_TOOLTIPS = ("The modified diffusion model.",) @@ -1842,7 +2020,7 @@ class Wan21BlockLoraSelect: arg_dict["blocks.{}.".format(i)] = argument return {"required": arg_dict} - + RETURN_TYPES = ("SELECTEDDITBLOCKS", ) RETURN_NAMES = ("blocks", ) OUTPUT_TOOLTIPS = ("The modified diffusion model.",) @@ -1853,7 +2031,7 @@ class Wan21BlockLoraSelect: def load_lora(self, **kwargs): return (kwargs,) - + class DiTBlockLoraLoader: def __init__(self): self.loaded_lora = None @@ -1861,17 +2039,17 @@ class DiTBlockLoraLoader: @classmethod def INPUT_TYPES(s): return {"required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - - }, - "optional": { - "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), - "opt_lora_path": ("STRING", {"forceInput": True, "tooltip": "Absolute path of the LoRA."}), - "blocks": ("SELECTEDDITBLOCKS",), - } - } - + "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), + "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), + + }, + "optional": { + "lora_name": (folder_paths.get_filename_list("loras"), {"tooltip": "The name of the LoRA."}), + "opt_lora_path": ("STRING", {"forceInput": True, "tooltip": "Absolute path of the LoRA."}), + "blocks": ("SELECTEDDITBLOCKS",), + } + } + RETURN_TYPES = ("MODEL", "STRING", ) RETURN_NAMES = ("model", "rank", ) OUTPUT_TOOLTIPS = ("The modified diffusion model.", "possible rank of the LoRA.") @@ -1879,21 +2057,21 @@ class DiTBlockLoraLoader: CATEGORY = "KJNodes/experimental" def load_lora(self, model, strength_model, lora_name=None, opt_lora_path=None, blocks=None): - + import comfy.lora if opt_lora_path: lora_path = opt_lora_path else: lora_path = folder_paths.get_full_path("loras", lora_name) - + lora = None if self.loaded_lora is not None: if self.loaded_lora[0] == lora_path: lora = self.loaded_lora[1] else: self.loaded_lora = None - + if lora is None: lora = load_torch_file(lora_path, safe_load=True) self.loaded_lora = (lora_path, lora) @@ -1957,15 +2135,15 @@ class DiTBlockLoraLoader: if model is not None: new_modelpatcher = model.clone() - k = new_modelpatcher.add_patches(loaded, strength_model) - + k = new_modelpatcher.add_patches(loaded, strength_model) + k = set(k) for x in loaded: if (x not in k): print("NOT LOADED {}".format(x)) return (new_modelpatcher, rank) - + class CustomControlNetWeightsFluxFromList: @classmethod def INPUT_TYPES(s): @@ -1979,7 +2157,7 @@ class CustomControlNetWeightsFluxFromList: "autosize": ("ACNAUTOSIZE", {"padding": 0}), } } - + RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",) RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT") FUNCTION = "load_weights" @@ -1989,7 +2167,7 @@ class CustomControlNetWeightsFluxFromList: def load_weights(self, list_of_floats: list[float], uncond_multiplier: 
float=1.0, cn_extras: dict[str]={}): - + adv_control = importlib.import_module("ComfyUI-Advanced-ControlNet.adv_control") ControlWeights = adv_control.utils.ControlWeights TimestepKeyframeGroup = adv_control.utils.TimestepKeyframeGroup @@ -1998,7 +2176,7 @@ class CustomControlNetWeightsFluxFromList: weights = ControlWeights.controlnet(weights_input=list_of_floats, uncond_multiplier=uncond_multiplier, extras=cn_extras) print(weights.weights_input) return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights))) - + SHAKKERLABS_UNION_CONTROLNET_TYPES = { "canny": 0, "tile": 1, @@ -2051,7 +2229,7 @@ class ModelSaveKJ: def save(self, model, filename_prefix, model_key_prefix, prompt=None, extra_pnginfo=None): from comfy.utils import save_torch_file full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - + output_checkpoint = f"{filename}_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) @@ -2077,7 +2255,7 @@ class ModelSaveKJ: os.makedirs(full_output_folder) save_torch_file(new_sd, os.path.join(full_output_folder, output_checkpoint)) return {} - + class StyleModelApplyAdvanced: @classmethod def INPUT_TYPES(s): @@ -2107,12 +2285,12 @@ class AudioConcatenate: "audio1": ("AUDIO",), "audio2": ("AUDIO",), "direction": ( - [ 'right', - 'left', - ], - { - "default": 'right' - }), + [ 'right', + 'left', + ], + { + "default": 'right' + }), }} RETURN_TYPES = ("AUDIO",) @@ -2127,7 +2305,7 @@ Concatenates the audio1 to audio2 in the specified direction. sample_rate_2 = audio2["sample_rate"] if sample_rate_1 != sample_rate_2: raise Exception("Sample rates of the two audios do not match") - + waveform_1 = audio1["waveform"] print(waveform_1.shape) waveform_2 = audio2["waveform"] @@ -2183,7 +2361,7 @@ class LeapfusionHunyuanI2V: inp[:, :, [index], :, :] = torch.zeros(1) return apply_model(inp, timestep, **c) return unet_wrapper - + samples = latent["samples"] * 0.476986 * strength m = model.clone() m.set_model_unet_function_wrapper(outer_wrapper(samples, index, start_percent, end_percent)) @@ -2293,9 +2471,9 @@ class VAELoaderKJ: "required": { "vae_name": (s.vae_list(), ), "device": (["main_device", "cpu"],), "weight_dtype": (["bf16", "fp16", "fp32" ],), - } - } - + } + } + RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" CATEGORY = "KJNodes/vae" @@ -2349,20 +2527,20 @@ class Guider_ScheduledCFG(CFGGuider): uncond = None cfg = 1.0 - return sampling_function(self.inner_model, x, timestep, uncond, self.conds.get("positive", None), cfg, model_options=model_options, seed=seed) - + return sampling_function(self.inner_model, x, timestep, uncond, self.conds.get("positive", None), cfg, model_options=model_options, seed=seed) + class ScheduledCFGGuidance: @classmethod def INPUT_TYPES(s): - return {"required": { - "model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.01}), - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step":0.01}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01}), - }, - } + return {"required": { + "model": ("MODEL",), + "positive": ("CONDITIONING", ), + "negative": ("CONDITIONING", ), + "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 100.0, "step": 0.01}), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step":0.01}), + "end_percent": ("FLOAT", {"default": 1.0, 
"min": 0.0, "max": 1.0, "step":0.01}), + }, + } RETURN_TYPES = ("GUIDER",) FUNCTION = "get_guider" CATEGORY = "KJNodes/experimental" @@ -2372,11 +2550,11 @@ cfg input can be a list of floats matching step count, or a single float for all """ def get_guider(self, model, cfg, positive, negative, start_percent, end_percent): - guider = Guider_ScheduledCFG(model) + guider = Guider_ScheduledCFG(model) guider.set_conds(positive, negative) guider.set_cfg(cfg, start_percent, end_percent) return (guider, ) - + class ApplyRifleXRoPE_WanVideo: @classmethod @@ -2386,7 +2564,7 @@ class ApplyRifleXRoPE_WanVideo: "model": ("MODEL",), "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}), "k": ("INT", {"default": 6, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}), - } + } } RETURN_TYPES = ("MODEL",) @@ -2397,23 +2575,23 @@ class ApplyRifleXRoPE_WanVideo: def patch(self, model, latent, k): model_class = model.model.diffusion_model - + model_clone = model.clone() num_frames = latent["samples"].shape[2] d = model_class.dim // model_class.num_heads rope_embedder = EmbedND_RifleX( - d, - 10000.0, + d, + 10000.0, [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)], num_frames, k - ) - + ) + model_clone.add_object_patch(f"diffusion_model.rope_embedder", rope_embedder) - + return (model_clone, ) - + class ApplyRifleXRoPE_HunuyanVideo: @classmethod def INPUT_TYPES(s): @@ -2422,7 +2600,7 @@ class ApplyRifleXRoPE_HunuyanVideo: "model": ("MODEL",), "latent": ("LATENT", {"tooltip": "Only used to get the latent count"}), "k": ("INT", {"default": 4, "min": 1, "max": 100, "step": 1, "tooltip": "Index of intrinsic frequency"}), - } + } } RETURN_TYPES = ("MODEL",) @@ -2433,20 +2611,20 @@ class ApplyRifleXRoPE_HunuyanVideo: def patch(self, model, latent, k): model_class = model.model.diffusion_model - + model_clone = model.clone() num_frames = latent["samples"].shape[2] pe_embedder = EmbedND_RifleX( - model_class.params.hidden_size // model_class.params.num_heads, - model_class.params.theta, - model_class.params.axes_dim, + model_class.params.hidden_size // model_class.params.num_heads, + model_class.params.theta, + model_class.params.axes_dim, num_frames, k - ) - + ) + model_clone.add_object_patch(f"diffusion_model.pe_embedder", pe_embedder) - + return (model_clone, ) def rope_riflex(pos, dim, theta, L_test, k): @@ -2495,18 +2673,18 @@ class Timer: class TimerNodeKJ: @classmethod - + def INPUT_TYPES(s): - return { - "required": { - "any_input": (IO.ANY, ), - "mode": (["start", "stop"],), - "name": ("STRING", {"default": "Timer"}), - }, - "optional": { - "timer": ("TIMER",), - }, - } + return { + "required": { + "any_input": (IO.ANY, ), + "mode": (["start", "stop"],), + "name": ("STRING", {"default": "Timer"}), + }, + "optional": { + "timer": ("TIMER",), + }, + } RETURN_TYPES = (IO.ANY, "TIMER", "INT", ) RETURN_NAMES = ("any_output", "timer", "time") @@ -2516,12 +2694,12 @@ class TimerNodeKJ: def timer(self, mode, name, any_input=None, timer=None): if timer is None: if mode == "start": - timer = Timer(name=name) + timer = Timer(name=name) timer.start_time = time.time() return {"ui": { - "text": [f"{timer.start_time}"]}, - "result": (any_input, timer, 0) - } + "text": [f"{timer.start_time}"]}, + "result": (any_input, timer, 0) + } elif mode == "stop" and timer is not None: end_time = time.time() timer.elapsed = int((end_time - timer.start_time) * 1000) @@ -2532,21 +2710,21 @@ class HunyuanVideoEncodeKeyframesToCond: @classmethod def INPUT_TYPES(s): return {"required": { - "model": 
("MODEL",), - "positive": ("CONDITIONING", ), - "vae": ("VAE", ), - "start_frame": ("IMAGE", ), - "end_frame": ("IMAGE", ), - "num_frames": ("INT", {"default": 33, "min": 2, "max": 4096, "step": 1}), - "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), - "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), - }, - "optional": { - "negative": ("CONDITIONING", ), - } - } + "model": ("MODEL",), + "positive": ("CONDITIONING", ), + "vae": ("VAE", ), + "start_frame": ("IMAGE", ), + "end_frame": ("IMAGE", ), + "num_frames": ("INT", {"default": 33, "min": 2, "max": 4096, "step": 1}), + "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), + "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), + }, + "optional": { + "negative": ("CONDITIONING", ), + } + } RETURN_TYPES = ("MODEL", "CONDITIONING","CONDITIONING","LATENT") RETURN_NAMES = ("model", "positive", "negative", "latent") @@ -2560,7 +2738,7 @@ class HunyuanVideoEncodeKeyframesToCond: model_clone.add_object_patch("concat_keys", ("concat_image",)) - + x = (start_frame.shape[1] // 8) * 8 y = (start_frame.shape[2] // 8) * 8 @@ -2593,7 +2771,7 @@ class HunyuanVideoEncodeKeyframesToCond: if len(out) == 1: out.append(out[0]) return (model_clone, out[0], out[1], out_latent) - + class LazySwitchKJ: def __init__(self): @@ -2622,4 +2800,4 @@ class LazySwitchKJ: def switch(self, switch, on_false = None, on_true=None): value = on_true if switch else on_false - return (value,) \ No newline at end of file + return (value,)