Mirror of https://git.datalinker.icu/comfyanonymous/ComfyUI (synced 2025-12-09 05:54:24 +08:00)

Commit e9cb8a966f: Merge 7cecb6dbf8c3fd7eec6550f1ce21625d2497d01a into fd271dedfde6e192a1f1a025521070876e89e04a
comfy/cli_args.py
@@ -145,6 +145,8 @@ parser.add_argument("--force-non-blocking", action="store_true", help="Force Com
 parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
 parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
+parser.add_argument("--use-subprocess-workers", action="store_true", help="Execute each prompt in an isolated subprocess with complete GPU/ROCm context reset. Ensures clean state between jobs but adds startup overhead.")
+parser.add_argument("--subprocess-timeout", type=int, default=600, help="Timeout in seconds for subprocess execution (default: 600, only used with --use-subprocess-workers).")
 parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")


 class PerformanceFeature(enum.Enum):
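In practice the two new flags are used together; for example `python main.py --use-subprocess-workers --subprocess-timeout 1200` runs every queued prompt in a freshly spawned process and gives each job up to 20 minutes before the parent kills the worker and returns a timeout error.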
comfy/execution_core.py (new file, 145 lines)
@@ -0,0 +1,145 @@
"""Core execution logic shared between normal and subprocess execution modes."""

import logging
import time

_active_worker = None


def create_worker(server_instance):
    """Create worker backend. Returns NativeWorker or SubprocessWorker."""
    global _active_worker
    from comfy.cli_args import args

    server = WorkerServer(server_instance)

    if args.use_subprocess_workers:
        from comfy.worker_process import SubprocessWorker
        worker = SubprocessWorker(server, timeout=args.subprocess_timeout)
    else:
        from comfy.worker_native import NativeWorker
        worker = NativeWorker(server)

    _active_worker = worker
    return worker
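For orientation, this is the startup wiring, condensed from the main.py hunk at the end of this diff: the server builds one worker at startup and hands it to the prompt loop.

    from comfy.execution_core import create_worker, prompt_worker

    worker = create_worker(prompt_server)                               # picks the backend from cli args
    node_count = asyncio_loop.run_until_complete(worker.initialize())   # loads nodes, builds the executor
    threading.Thread(target=prompt_worker, daemon=True,
                     args=(prompt_server.prompt_queue, worker), name="PromptWorker").start()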
async def init_execution_environment():
    """Load nodes and custom nodes. Returns number of node types loaded."""
    import nodes
    from comfy.cli_args import args

    await nodes.init_extra_nodes(
        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
        init_api_nodes=not args.disable_api_nodes
    )
    return len(nodes.NODE_CLASS_MAPPINGS)


def setup_progress_hook(server_instance, interrupt_checker):
    """Set up global progress hook. interrupt_checker must raise on interrupt."""
    import comfy.utils
    from comfy_execution.progress import get_progress_state
    from comfy_execution.utils import get_executing_context

    def hook(value, total, preview_image, prompt_id=None, node_id=None):
        ctx = get_executing_context()
        if ctx:
            prompt_id = prompt_id or ctx.prompt_id
            node_id = node_id or ctx.node_id

        interrupt_checker()

        prompt_id = prompt_id or server_instance.last_prompt_id
        node_id = node_id or server_instance.last_node_id

        get_progress_state().update_progress(node_id, value, total, preview_image)
        server_instance.send_sync("progress", {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}, server_instance.client_id)

    comfy.utils.set_progress_bar_global_hook(hook)
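The interrupt_checker contract is simply "raise to abort". A sketch of supplying a custom checker, modeled on what the subprocess child later in this diff does with an mp.Event:

    import comfy.model_management

    def check_interrupt():
        # interrupt_event: a multiprocessing.Event set by the parent process
        if interrupt_event.is_set():
            raise comfy.model_management.InterruptProcessingException()

    setup_progress_hook(server, interrupt_checker=check_interrupt)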
class WorkerServer:
    """Protocol boundary: client_id, last_node_id, last_prompt_id, sockets_metadata, send_sync(), queue_updated()"""

    _WRITABLE = {'client_id', 'last_node_id', 'last_prompt_id'}

    def __init__(self, server):
        object.__setattr__(self, '_server', server)

    def __setattr__(self, name, value):
        if name in self._WRITABLE:
            setattr(self._server, name, value)
        else:
            raise AttributeError(f"WorkerServer does not accept attribute '{name}'")

    @property
    def client_id(self):
        return self._server.client_id

    @property
    def last_node_id(self):
        return self._server.last_node_id

    @property
    def last_prompt_id(self):
        return self._server.last_prompt_id

    @property
    def sockets_metadata(self):
        return self._server.sockets_metadata

    def send_sync(self, event, data, sid=None):
        self._server.send_sync(event, data, sid or self.client_id)

    def queue_updated(self):
        self._server.queue_updated()
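A quick illustration of the boundary this wrapper enforces: only the three whitelisted attributes are forwarded as writes, and anything else fails loudly.

    server = WorkerServer(prompt_server)
    server.last_prompt_id = "abc123"   # forwarded to the wrapped server
    server.send_sync("status", {})     # sid defaults to the current client_id
    server.new_field = 1               # raises AttributeError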
def interrupt_processing(value=True):
    _active_worker.interrupt(value)


def _strip_sensitive(prompt):
    # Drop index 5 (the sensitive extra data) from the queue item tuple
    # before the item is stored in history.
    return prompt[:5] + prompt[6:]


def prompt_worker(q, worker):
    """Main prompt execution loop."""
    import execution

    server = worker.server_instance

    while True:
        queue_item = q.get(timeout=worker.get_gc_timeout())
        if queue_item is not None:
            item, item_id = queue_item
            start_time = time.perf_counter()
            prompt_id = item[1]
            server.last_prompt_id = prompt_id

            # item layout (from the old main.py loop): item[1]=prompt_id,
            # item[2]=prompt, item[3]=extra_data, item[4]=outputs, item[5]=sensitive data
            extra_data = {**item[3], **item[5]}

            result = worker.execute_prompt(item[2], prompt_id, extra_data, item[4], server=server)
            worker.mark_needs_gc()

            q.task_done(
                item_id,
                result['history_result'],
                status=execution.PromptQueue.ExecutionStatus(
                    status_str='success' if result['success'] else 'error',
                    completed=result['success'],
                    messages=result['status_messages']
                ),
                process_item=_strip_sensitive
            )

            if server.client_id is not None:
                server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)

            elapsed = time.perf_counter() - start_time
            if elapsed > 600:
                logging.info(f"Prompt executed in {time.strftime('%H:%M:%S', time.gmtime(elapsed))}")
            else:
                logging.info(f"Prompt executed in {elapsed:.2f} seconds")

        worker.handle_flags(q.get_flags())
comfy/worker_native.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""Native (in-process) worker for prompt execution."""

import time
import gc


class NativeWorker:
    """Executes prompts in the same process as the server."""

    def __init__(self, server_instance, interrupt_checker=None):
        self.server_instance = server_instance
        self.interrupt_checker = interrupt_checker
        self.executor = None
        self.last_gc_collect = 0
        self.need_gc = False
        self.gc_collect_interval = 10.0

    async def initialize(self):
        """Load nodes and set up executor. Returns node count."""
        from execution import PromptExecutor, CacheType
        from comfy.cli_args import args
        from comfy.execution_core import init_execution_environment, setup_progress_hook
        import comfy.model_management as mm
        import hook_breaker_ac10a0

        hook_breaker_ac10a0.save_functions()
        try:
            node_count = await init_execution_environment()
        finally:
            hook_breaker_ac10a0.restore_functions()

        interrupt_checker = self.interrupt_checker or mm.throw_exception_if_processing_interrupted
        setup_progress_hook(self.server_instance, interrupt_checker=interrupt_checker)

        cache_type = CacheType.CLASSIC
        if args.cache_lru > 0:
            cache_type = CacheType.LRU
        elif args.cache_ram > 0:
            cache_type = CacheType.RAM_PRESSURE
        elif args.cache_none:
            cache_type = CacheType.NONE

        self.executor = PromptExecutor(
            self.server_instance,
            cache_type=cache_type,
            cache_args={"lru": args.cache_lru, "ram": args.cache_ram}
        )
        return node_count

    def execute_prompt(self, prompt, prompt_id, extra_data, execute_outputs, server=None):
        self.executor.execute(prompt, prompt_id, extra_data, execute_outputs)
        return {
            'success': self.executor.success,
            'history_result': self.executor.history_result,
            'status_messages': self.executor.status_messages,
            'prompt_id': prompt_id
        }

    def handle_flags(self, flags):
        import comfy.model_management as mm
        import hook_breaker_ac10a0

        free_memory = flags.get("free_memory", False)

        if flags.get("unload_models", free_memory):
            mm.unload_all_models()
            self.need_gc = True
            self.last_gc_collect = 0

        if free_memory:
            if self.executor:
                self.executor.reset()
            self.need_gc = True
            self.last_gc_collect = 0

        if self.need_gc:
            current_time = time.perf_counter()
            if (current_time - self.last_gc_collect) > self.gc_collect_interval:
                gc.collect()
                mm.soft_empty_cache()
                self.last_gc_collect = current_time
                self.need_gc = False
                hook_breaker_ac10a0.restore_functions()

    def interrupt(self, value=True):
        import comfy.model_management
        comfy.model_management.interrupt_current_processing(value)

    def mark_needs_gc(self):
        self.need_gc = True

    def get_gc_timeout(self):
        if self.need_gc:
            return max(self.gc_collect_interval - (time.perf_counter() - self.last_gc_collect), 0.0)
        return 1000.0
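A minimal sketch of driving a NativeWorker by hand, which create_worker() and prompt_worker() normally do; `prompt_graph`, `prompt_id`, and the output node id `"9"` are placeholders, and the flags dicts mirror the keys handle_flags() actually reads:

    worker = NativeWorker(prompt_server)
    asyncio_loop.run_until_complete(worker.initialize())   # loads nodes, installs the progress hook

    result = worker.execute_prompt(prompt_graph, prompt_id, extra_data={}, execute_outputs=["9"])
    if not result['success']:
        print(result['status_messages'])

    worker.handle_flags({"unload_models": True})   # unload models and schedule a gc pass
    worker.handle_flags({"free_memory": True})     # additionally resets the executor cache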
comfy/worker_process.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""Subprocess worker for isolated prompt execution with complete GPU/ROCm reset."""

import logging
import multiprocessing as mp
import time
import traceback

mp.set_start_method('spawn', force=True)


def _deserialize_preview(msg):
    """Deserialize preview image from IPC transport."""
    if not (isinstance(msg['data'], dict) and msg['data'].get('_serialized')):
        return msg

    from PIL import Image
    from io import BytesIO
    import base64

    s = msg['data']
    pil_image = Image.open(BytesIO(base64.b64decode(s['image_bytes'])))
    msg['data'] = ((s['image_type'], pil_image, s['max_size']), s['metadata'])
    return msg
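A round-trip sketch of the preview transport: IPCMessageServer.send_sync() in worker_process_child.py (below) produces the `_serialized` dict that this helper unpacks. Note the deserializer keys off `data['_serialized']` rather than the event name, so the event value here is illustrative, and `b64_jpeg` stands in for real base64-encoded JPEG bytes:

    msg = {'event': 'preview_event', 'sid': None,
           'data': {'_serialized': True, 'image_type': 'JPEG',
                    'image_bytes': b64_jpeg, 'max_size': 512, 'metadata': {}}}
    msg = _deserialize_preview(msg)
    (image_type, pil_image, max_size), metadata = msg['data']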
def _error_result(worker_id, prompt_id, error, tb=None):
    return {
        'success': False,
        'error': error,
        'traceback': tb,
        'history_result': {},
        'status_messages': [],
        'worker_id': worker_id,
        'prompt_id': prompt_id
    }


def _kill_worker(worker, worker_id):
    if not worker.is_alive():
        return
    worker.terminate()
    worker.join(timeout=2)
    if worker.is_alive():
        logging.warning(f"Worker {worker_id} didn't terminate, killing")
        worker.kill()
        worker.join()
class SubprocessWorker:
    """Executes each prompt in an isolated subprocess with fresh GPU context."""

    def __init__(self, server_instance, timeout=600):
        self.server_instance = server_instance
        self.timeout = timeout
        self.worker_counter = 0
        self.current_worker = None
        self.interrupt_event = None
        logging.info("SubprocessWorker created - each job will run in an isolated process")

    async def initialize(self):
        """Load node definitions for prompt validation. Returns node count."""
        from comfy.execution_core import init_execution_environment
        return await init_execution_environment()

    def handle_flags(self, flags):
        pass

    def mark_needs_gc(self):
        pass

    def get_gc_timeout(self):
        return 1000.0

    def interrupt(self, value=True):
        if not value:
            return
        if self.interrupt_event:
            self.interrupt_event.set()
        if self.current_worker and self.current_worker.is_alive():
            self.current_worker.join(timeout=2)
            _kill_worker(self.current_worker, self.worker_counter)
            self.current_worker = None

    def _relay_messages(self, message_queue, server):
        """Relay queued messages to UI."""
        while not message_queue.empty():
            try:
                msg = _deserialize_preview(message_queue.get_nowait())
                if server:
                    server.send_sync(msg['event'], msg['data'], msg['sid'])
            except Exception:  # narrowed from a bare except; stop relaying on any error
                break
    def execute_prompt(self, prompt, prompt_id, extra_data={}, execute_outputs=[], server=None):
        self.worker_counter += 1
        worker_id = self.worker_counter

        job_queue = mp.Queue()
        result_queue = mp.Queue()
        message_queue = mp.Queue()
        self.interrupt_event = mp.Event()

        client_id = extra_data.get('client_id')
        client_metadata = {}
        if client_id and hasattr(server, 'sockets_metadata'):
            client_metadata = server.sockets_metadata.get(client_id, {})

        job_data = {
            'prompt': prompt,
            'prompt_id': prompt_id,
            'extra_data': extra_data,
            'execute_outputs': execute_outputs,
            'client_sockets_metadata': client_metadata
        }

        from comfy.worker_process_child import worker_main
        worker = mp.Process(
            target=worker_main,
            args=(job_queue, result_queue, message_queue, self.interrupt_event, worker_id),
            name=f'ComfyUI-Worker-{worker_id}'
        )

        logging.info(f"Starting worker {worker_id} for prompt {prompt_id}")
        self.current_worker = worker
        worker.start()
        job_queue.put(job_data)

        try:
            start_time = time.time()
            result = None

            # Poll for the result while relaying progress messages and
            # checking for interrupts and the overall timeout.
            while result is None:
                if self.interrupt_event.is_set():
                    logging.info(f"Worker {worker_id} interrupted")
                    if server:
                        server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
                    return _error_result(worker_id, prompt_id, 'Execution interrupted by user')

                if time.time() - start_time > self.timeout:
                    raise TimeoutError()

                self._relay_messages(message_queue, server)

                try:
                    result = result_queue.get(timeout=0.1)
                except mp.queues.Empty:  # multiprocessing.queues re-exports queue.Empty
                    pass

            self._relay_messages(message_queue, server)

            worker.join(timeout=5)
            if worker.is_alive():
                _kill_worker(worker, worker_id)

            logging.info(f"Worker {worker_id} cleaned up (exit code: {worker.exitcode})")
            self.current_worker = None
            return result

        except TimeoutError:
            error = f"Worker {worker_id} timed out after {self.timeout}s. Use --subprocess-timeout to raise the limit."
            logging.error(error)
            _kill_worker(worker, worker_id)
            self.current_worker = None
            return _error_result(worker_id, prompt_id, error)

        except Exception as e:
            error = f"Worker {worker_id} IPC error: {e}"
            logging.error(f"{error}\n{traceback.format_exc()}")
            _kill_worker(worker, worker_id)
            self.current_worker = None
            return _error_result(worker_id, prompt_id, error, traceback.format_exc())

        finally:
            for q in (job_queue, result_queue, message_queue):
                q.close()
                try:
                    q.join_thread()
                except Exception:  # narrowed from a bare except
                    pass
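Worth tracing end to end: a UI interrupt now flows from nodes.interrupt_processing() (see the nodes.py hunk below) through comfy.execution_core.interrupt_processing() to SubprocessWorker.interrupt(), which sets the shared mp.Event; the polling loop above sees it and returns an error result, while inside the child the interrupt checker raises InterruptProcessingException to stop the running node.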
comfy/worker_process_child.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""Subprocess worker child process entry point."""

import logging
import multiprocessing as mp
import traceback


class IPCMessageServer:
    """IPC-based message server for subprocess workers."""

    def __init__(self, message_queue, client_id=None, sockets_metadata=None):
        self.message_queue = message_queue
        self.client_id = client_id
        self.last_node_id = None
        self.last_prompt_id = None
        self.sockets_metadata = sockets_metadata or {}

    def send_sync(self, event, data, sid=None):
        from protocol import BinaryEventTypes
        from io import BytesIO
        import base64

        if event == BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA and isinstance(data, tuple):
            preview_image, metadata = data
            image_type, pil_image, max_size = preview_image

            buffer = BytesIO()
            pil_image.save(buffer, format=image_type)

            data = {
                '_serialized': True,
                'image_type': image_type,
                'image_bytes': base64.b64encode(buffer.getvalue()).decode('utf-8'),
                'max_size': max_size,
                'metadata': metadata
            }

        self.message_queue.put_nowait({'event': event, 'data': data, 'sid': sid})

    def queue_updated(self):
        pass
def worker_main(job_queue, result_queue, message_queue, interrupt_event, worker_id):
    """Subprocess worker entry point - spawned fresh for each execution."""
    job_data = None
    try:
        logging.basicConfig(level=logging.INFO, format=f'[Worker-{worker_id}] %(levelname)s: %(message)s')
        logging.info(f"Worker {worker_id} starting (PID: {mp.current_process().pid})")

        import asyncio
        import comfy.model_management
        from comfy.worker_native import NativeWorker
        from comfy.execution_core import WorkerServer

        logging.info(f"Worker {worker_id} initialized. Device: {comfy.model_management.get_torch_device()}")

        job_data = job_queue.get(timeout=30)
        client_id = job_data.get('extra_data', {}).get('client_id')
        client_metadata = job_data.get('client_sockets_metadata', {})

        sockets_metadata = {client_id: client_metadata} if client_id and client_metadata else {}
        ipc_server = IPCMessageServer(message_queue, client_id, sockets_metadata)
        server = WorkerServer(ipc_server)

        def check_interrupt():
            if interrupt_event.is_set():
                raise comfy.model_management.InterruptProcessingException()

        worker = NativeWorker(server, interrupt_checker=check_interrupt)

        import comfy.execution_core
        comfy.execution_core._active_worker = worker

        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        node_count = loop.run_until_complete(worker.initialize())
        logging.info(f"Worker {worker_id} loaded {node_count} node types")

        result = worker.execute_prompt(
            job_data['prompt'],
            job_data['prompt_id'],
            job_data.get('extra_data', {}),
            job_data.get('execute_outputs', [])
        )
        result['worker_id'] = worker_id

        logging.info(f"Worker {worker_id} completed successfully")
        result_queue.put(result)

    except Exception as e:
        logging.error(f"Worker {worker_id} failed: {e}\n{traceback.format_exc()}")
        result_queue.put({
            'success': False,
            'error': str(e),
            'traceback': traceback.format_exc(),
            'history_result': {},
            'status_messages': [],
            'worker_id': worker_id,
            'prompt_id': job_data.get('prompt_id', 'unknown') if job_data else 'unknown'
        })

    finally:
        logging.info(f"Worker {worker_id} exiting")
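End to end, each job follows the same handshake: the parent spawns the child and puts job_data on job_queue; the child pulls it (30 s timeout), wraps an IPCMessageServer in a WorkerServer, registers itself as execution_core._active_worker, runs NativeWorker.initialize() and execute_prompt() inside a fresh event loop, and puts the result dict on result_queue, streaming previews and progress through message_queue along the way; the parent relays those messages to the UI, then joins and reaps the process.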
main.py (142 lines changed)
@@ -11,9 +11,6 @@ import itertools
 import utils.extra_config
 import logging
 import sys
-from comfy_execution.progress import get_progress_state
-from comfy_execution.utils import get_executing_context
-from comfy_api import feature_flags


 if __name__ == "__main__":
@@ -176,16 +173,22 @@ if 'torch' in sys.modules:

 import comfy.utils

-import execution
 import server
 from protocol import BinaryEventTypes
-import nodes
-import comfy.model_management
 import comfyui_version
 import app.logger
 import hook_breaker_ac10a0

+# Import modules needed for server operation
+# GPU initialization happens lazily when GPU functions are called
+# In subprocess mode, main process won't call GPU functions - workers will
+if __name__ == "__main__":
+    import execution
+    import nodes
+    import comfy.model_management


 def cuda_malloc_warning():
+    if args.use_subprocess_workers:
+        return
     device = comfy.model_management.get_torch_device()
     device_name = comfy.model_management.get_torch_device_name(device)
     cuda_malloc_warning = False
@@ -197,84 +200,6 @@ def cuda_malloc_warning():
         logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")


-def prompt_worker(q, server_instance):
-    current_time: float = 0.0
-    cache_type = execution.CacheType.CLASSIC
-    if args.cache_lru > 0:
-        cache_type = execution.CacheType.LRU
-    elif args.cache_ram > 0:
-        cache_type = execution.CacheType.RAM_PRESSURE
-    elif args.cache_none:
-        cache_type = execution.CacheType.NONE
-
-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
-    last_gc_collect = 0
-    need_gc = False
-    gc_collect_interval = 10.0
-
-    while True:
-        timeout = 1000.0
-        if need_gc:
-            timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
-
-        queue_item = q.get(timeout=timeout)
-        if queue_item is not None:
-            item, item_id = queue_item
-            execution_start_time = time.perf_counter()
-            prompt_id = item[1]
-            server_instance.last_prompt_id = prompt_id
-
-            sensitive = item[5]
-            extra_data = item[3].copy()
-            for k in sensitive:
-                extra_data[k] = sensitive[k]
-
-            e.execute(item[2], prompt_id, extra_data, item[4])
-            need_gc = True
-
-            remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
-            q.task_done(item_id,
-                        e.history_result,
-                        status=execution.PromptQueue.ExecutionStatus(
-                            status_str='success' if e.success else 'error',
-                            completed=e.success,
-                            messages=e.status_messages), process_item=remove_sensitive)
-            if server_instance.client_id is not None:
-                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)
-
-            current_time = time.perf_counter()
-            execution_time = current_time - execution_start_time
-
-            # Log Time in a more readable way after 10 minutes
-            if execution_time > 600:
-                execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time))
-                logging.info(f"Prompt executed in {execution_time}")
-            else:
-                logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
-
-        flags = q.get_flags()
-        free_memory = flags.get("free_memory", False)
-
-        if flags.get("unload_models", free_memory):
-            comfy.model_management.unload_all_models()
-            need_gc = True
-            last_gc_collect = 0
-
-        if free_memory:
-            e.reset()
-            need_gc = True
-            last_gc_collect = 0
-
-        if need_gc:
-            current_time = time.perf_counter()
-            if (current_time - last_gc_collect) > gc_collect_interval:
-                gc.collect()
-                comfy.model_management.soft_empty_cache()
-                last_gc_collect = current_time
-                need_gc = False
-                hook_breaker_ac10a0.restore_functions()
-
-
 async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
     addresses = []
     for addr in address.split(","):
@@ -283,37 +208,6 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
         server_instance.start_multi_address(addresses, call_on_start, verbose), server_instance.publish_loop()
     )

-def hijack_progress(server_instance):
-    def hook(value, total, preview_image, prompt_id=None, node_id=None):
-        executing_context = get_executing_context()
-        if prompt_id is None and executing_context is not None:
-            prompt_id = executing_context.prompt_id
-        if node_id is None and executing_context is not None:
-            node_id = executing_context.node_id
-        comfy.model_management.throw_exception_if_processing_interrupted()
-        if prompt_id is None:
-            prompt_id = server_instance.last_prompt_id
-        if node_id is None:
-            node_id = server_instance.last_node_id
-        progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
-        get_progress_state().update_progress(node_id, value, total, preview_image)
-
-        server_instance.send_sync("progress", progress, server_instance.client_id)
-        if preview_image is not None:
-            # Only send old method if client doesn't support preview metadata
-            if not feature_flags.supports_feature(
-                server_instance.sockets_metadata,
-                server_instance.client_id,
-                "supports_preview_metadata",
-            ):
-                server_instance.send_sync(
-                    BinaryEventTypes.UNENCODED_PREVIEW_IMAGE,
-                    preview_image,
-                    server_instance.client_id,
-                )
-
-    comfy.utils.set_progress_bar_global_hook(hook)
-
-
 def cleanup_temp():
     temp_dir = folder_paths.get_temp_directory()
@@ -356,20 +250,16 @@ def start_comfyui(asyncio_loop=None):
     if args.enable_manager and not args.disable_manager_ui:
         comfyui_manager.start()

-    hook_breaker_ac10a0.save_functions()
-    asyncio_loop.run_until_complete(nodes.init_extra_nodes(
-        init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
-        init_api_nodes=not args.disable_api_nodes
-    ))
-    hook_breaker_ac10a0.restore_functions()
+    from comfy.execution_core import create_worker, prompt_worker
+    worker = create_worker(prompt_server)
+    node_count = asyncio_loop.run_until_complete(worker.initialize())
+    logging.info(f"Loaded {node_count} node types")
+    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, worker), name="PromptWorker").start()

     cuda_malloc_warning()
     setup_database()

     prompt_server.add_routes()
-    hijack_progress(prompt_server)
-
-    threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()

     if args.quick_test_for_ci:
         exit(0)
nodes.py (3 lines changed)
@@ -50,7 +50,8 @@ def before_node_execution():
     comfy.model_management.throw_exception_if_processing_interrupted()

 def interrupt_processing(value=True):
-    comfy.model_management.interrupt_current_processing(value)
+    from comfy.execution_core import interrupt_processing as core_interrupt
+    core_interrupt(value)

 MAX_RESOLUTION=16384