mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-08 21:44:33 +08:00
Merge 09c250184daf60c46182570df8b911891a19c4b9 into 56fa7dbe380cb5591c5542f8aa51ce2fc26beedf
This commit is contained in:
commit
a9c73cd85a
@ -1,6 +1,12 @@
|
||||
from __future__ import annotations
|
||||
from typing import Type, Literal
|
||||
# graph.py — grouped/batched scheduler on top of the updated ExecutionList
|
||||
# Implements model-class batching to reduce device/context swaps while preserving
|
||||
# the new execution_cache behavior added upstream.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Type, Literal, Optional
|
||||
|
||||
import os
|
||||
import nodes
|
||||
import asyncio
|
||||
import inspect
|
||||
@ -10,15 +16,19 @@ from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputType
|
||||
# NOTE: ExecutionBlocker code got moved to graph_utils.py to prevent torch being imported too soon during unit tests
|
||||
ExecutionBlocker = ExecutionBlocker
|
||||
|
||||
|
||||
class DependencyCycleError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NodeInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class NodeNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DynamicPrompt:
|
||||
def __init__(self, original_prompt):
|
||||
# The original prompt provided by the user
|
||||
@ -62,6 +72,7 @@ class DynamicPrompt:
|
||||
def get_original_prompt(self):
|
||||
return self.original_prompt
|
||||
|
||||
|
||||
def get_input_info(
|
||||
class_def: Type[ComfyNodeABC],
|
||||
input_name: str,
|
||||
@ -99,12 +110,13 @@ def get_input_info(
|
||||
extra_info = {}
|
||||
return input_type, input_category, extra_info
|
||||
|
||||
|
||||
class TopologicalSort:
|
||||
def __init__(self, dynprompt):
|
||||
self.dynprompt = dynprompt
|
||||
self.pendingNodes = {}
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.blockCount = {} # Number of nodes this node is directly blocked by
|
||||
self.blocking = {} # Which nodes are blocked by this node
|
||||
self.externalBlocks = 0
|
||||
self.unblockedEvent = asyncio.Event()
|
||||
|
||||
@ -165,6 +177,7 @@ class TopologicalSort:
|
||||
assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
|
||||
self.externalBlocks += 1
|
||||
self.blockCount[node_id] += 1
|
||||
|
||||
def unblock():
|
||||
self.externalBlocks -= 1
|
||||
self.blockCount[node_id] -= 1
|
||||
@ -186,36 +199,49 @@ class TopologicalSort:
|
||||
def is_empty(self):
|
||||
return len(self.pendingNodes) == 0
|
||||
|
||||
|
||||
class ExecutionList(TopologicalSort):
|
||||
"""
|
||||
ExecutionList implements a topological dissolve of the graph. After a node is staged for execution,
|
||||
it can still be returned to the graph after having further dependencies added.
|
||||
ExecutionList implements a topological dissolve of the graph with batching.
|
||||
After a node is staged for execution, it can still be returned to the graph
|
||||
after having further dependencies added.
|
||||
|
||||
Batching: we favor running nodes of the same class_type back-to-back
|
||||
to reduce device/context thrash (e.g., model swaps). Within a batch we still
|
||||
apply UX-friendly priorities (output/async early, VAEDecode→preview, etc.).
|
||||
"""
|
||||
|
||||
def __init__(self, dynprompt, output_cache):
|
||||
super().__init__(dynprompt)
|
||||
self.output_cache = output_cache
|
||||
self.staged_node_id = None
|
||||
self.staged_node_id: Optional[str] = None
|
||||
|
||||
# Upstream execution cache (kept intact)
|
||||
self.execution_cache = {}
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
# Batching state
|
||||
self._current_group_class: Optional[str] = None
|
||||
|
||||
# ----------------------------- cache ---------------------------------
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
if to_node_id not in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
if not from_node_id in self.execution_cache_listeners:
|
||||
if from_node_id not in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
|
||||
def get_cache(self, from_node_id, to_node_id):
|
||||
if not to_node_id in self.execution_cache:
|
||||
if to_node_id not in self.execution_cache:
|
||||
return None
|
||||
value = self.execution_cache[to_node_id].get(from_node_id)
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
# Write back to the main cache on touch.
|
||||
self.output_cache.set(from_node_id, value)
|
||||
return value
|
||||
|
||||
@ -229,16 +255,93 @@ class ExecutionList(TopologicalSort):
|
||||
super().add_strong_link(from_node_id, from_socket, to_node_id)
|
||||
self.cache_link(from_node_id, to_node_id)
|
||||
|
||||
# --------------------------- group utils ------------------------------
|
||||
def _pick_largest_group(self, node_list):
|
||||
"""Return the class_type with the most representatives in node_list.
|
||||
Ties are resolved deterministically by class name."""
|
||||
counts = {}
|
||||
for nid in node_list:
|
||||
ctype = self.dynprompt.get_node(nid)["class_type"]
|
||||
counts[ctype] = counts.get(ctype, 0) + 1
|
||||
# max by (count, class_name) for deterministic tie-break
|
||||
return max(counts.items(), key=lambda kv: (kv[1], kv[0]))[0]
|
||||
|
||||
def _filter_by_group(self, node_list, group_cls):
|
||||
"""Keep only nodes that belong to the given class."""
|
||||
return [nid for nid in node_list if self.dynprompt.get_node(nid)["class_type"] == group_cls]
|
||||
|
||||
# ------------------------- node classification ------------------------
|
||||
def _is_output(self, node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'OUTPUT_NODE', False) is True
|
||||
|
||||
def _is_async(self, node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
# ------------------------- UX within a batch --------------------------
|
||||
def _pick_in_batch_with_ux(self, candidates):
|
||||
"""
|
||||
Original UX heuristics, but applied *within* the current batch.
|
||||
"""
|
||||
# 1) Output nodes ASAP
|
||||
for nid in candidates:
|
||||
if self._is_output(nid):
|
||||
return nid
|
||||
# 1b) Async nodes early to overlap
|
||||
for nid in candidates:
|
||||
if self._is_async(nid):
|
||||
return nid
|
||||
# 2) decoder-before-preview pattern (within the batch)
|
||||
for nid in candidates:
|
||||
for blocked in self.blocking[nid]:
|
||||
if self._is_output(blocked):
|
||||
return nid
|
||||
# 3) VAELoader -> VAEDecode -> preview (within the batch)
|
||||
for nid in candidates:
|
||||
for blocked in self.blocking[nid]:
|
||||
for blocked2 in self.blocking[blocked]:
|
||||
if self._is_output(blocked2):
|
||||
return nid
|
||||
# 4) Otherwise, first candidate
|
||||
return candidates[0]
|
||||
|
||||
# ------------------------- batch-aware picking ------------------------
|
||||
def ux_friendly_pick_node(self, available):
|
||||
"""
|
||||
Choose which ready node to execute next, honoring the current batch.
|
||||
When the current batch runs dry, switch to the largest ready group.
|
||||
"""
|
||||
|
||||
# Ensure current batch is still present; otherwise pick a new largest group.
|
||||
has_current = (
|
||||
self._current_group_class is not None and
|
||||
any(self.dynprompt.get_node(nid)["class_type"] == self._current_group_class for nid in available)
|
||||
)
|
||||
if not has_current:
|
||||
new_group = self._pick_largest_group(available)
|
||||
self._current_group_class = new_group
|
||||
|
||||
# Restrict to nodes of the current batch
|
||||
candidates = self._filter_by_group(available, self._current_group_class)
|
||||
return self._pick_in_batch_with_ux(candidates)
|
||||
|
||||
# --------------------------- staging / run ----------------------------
|
||||
async def stage_node_execution(self):
|
||||
assert self.staged_node_id is None
|
||||
if self.is_empty():
|
||||
return None, None, None
|
||||
|
||||
available = self.get_ready_nodes()
|
||||
|
||||
# If nothing ready but there are external blockers, wait for unblocks.
|
||||
while len(available) == 0 and self.externalBlocks > 0:
|
||||
# Wait for an external block to be released
|
||||
await self.unblockedEvent.wait()
|
||||
self.unblockedEvent.clear()
|
||||
available = self.get_ready_nodes()
|
||||
|
||||
if len(available) == 0:
|
||||
cycled_nodes = self.get_nodes_in_cycle()
|
||||
# Because cycles composed entirely of static nodes are caught during initial validation,
|
||||
@ -259,64 +362,30 @@ class ExecutionList(TopologicalSort):
|
||||
}
|
||||
return None, error_details, ex
|
||||
|
||||
# Batch-aware pick
|
||||
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||
return self.staged_node_id, None, None
|
||||
|
||||
def ux_friendly_pick_node(self, node_list):
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
def is_output(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
return True
|
||||
return False
|
||||
|
||||
# If an available node is async, do that first.
|
||||
# This will execute the asynchronous function earlier, reducing the overall time.
|
||||
def is_async(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return inspect.iscoroutinefunction(getattr(class_def, class_def.FUNCTION))
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id) or is_async(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
if is_output(blocked_node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||
if is_output(blocked_node_id1):
|
||||
return node_id
|
||||
|
||||
#TODO: this function should be improved
|
||||
return node_list[0]
|
||||
|
||||
def unstage_node_execution(self):
|
||||
# If a node execution resolves to PENDING, return it to the pool
|
||||
# but keep the current batch so we continue batching next time.
|
||||
assert self.staged_node_id is not None
|
||||
self.staged_node_id = None
|
||||
|
||||
def complete_node_execution(self):
|
||||
node_id = self.staged_node_id
|
||||
self.pop_node(node_id)
|
||||
# Maintain current batch; it will switch automatically when empty.
|
||||
self.execution_cache.pop(node_id, None)
|
||||
self.execution_cache_listeners.pop(node_id, None)
|
||||
self.staged_node_id = None
|
||||
|
||||
# ------------------------- cycle detection ----------------------------
|
||||
def get_nodes_in_cycle(self):
|
||||
# We'll dissolve the graph in reverse topological order to leave only the nodes in the cycle.
|
||||
# We're skipping some of the performance optimizations from the original TopologicalSort to keep
|
||||
# the code simple (and because having a cycle in the first place is a catastrophic error)
|
||||
blocked_by = { node_id: {} for node_id in self.pendingNodes }
|
||||
blocked_by = {node_id: {} for node_id in self.pendingNodes}
|
||||
for from_node_id in self.blocking:
|
||||
for to_node_id in self.blocking[from_node_id]:
|
||||
if True in self.blocking[from_node_id][to_node_id].values():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user