ComfyUI/comfy/worker_process.py
2025-12-07 08:27:56 -08:00

180 lines
5.8 KiB
Python

"""Subprocess worker for isolated prompt execution with complete GPU/ROCm reset."""
import logging
import multiprocessing as mp
import queue
import time
import traceback
# Select the 'spawn' start method as an import-time side effect; force=True
# overrides any start method that was already chosen for this interpreter.
# NOTE(review): presumably needed so each worker starts from a fresh
# interpreter (and therefore a fresh GPU context) instead of a fork - confirm.
mp.set_start_method('spawn', force=True)
def _deserialize_preview(msg):
"""Deserialize preview image from IPC transport."""
if not (isinstance(msg['data'], dict) and msg['data'].get('_serialized')):
return msg
from PIL import Image
from io import BytesIO
import base64
s = msg['data']
pil_image = Image.open(BytesIO(base64.b64decode(s['image_bytes'])))
msg['data'] = ((s['image_type'], pil_image, s['max_size']), s['metadata'])
return msg
def _error_result(worker_id, prompt_id, error, tb=None):
return {
'success': False,
'error': error,
'traceback': tb,
'history_result': {},
'status_messages': [],
'worker_id': worker_id,
'prompt_id': prompt_id
}
def _kill_worker(worker, worker_id):
if not worker.is_alive():
return
worker.terminate()
worker.join(timeout=2)
if worker.is_alive():
logging.warning(f"Worker {worker_id} didn't terminate, killing")
worker.kill()
worker.join()
class SubprocessWorker:
    """Executes each prompt in an isolated subprocess with fresh GPU context.

    Every call to :meth:`execute_prompt` starts a new ``multiprocessing``
    process running ``worker_main`` and talks to it over three queues (job,
    result, UI messages) plus an interrupt event.

    Attributes:
        server_instance: Object supplied by the creator; stored as-is.
        timeout: Maximum number of seconds a single prompt may run.
        worker_counter: Id of the most recently started worker (1-based).
        current_worker: The process started by the running
            :meth:`execute_prompt` call, or ``None``.
        interrupt_event: Event shared with the most recent worker, or
            ``None`` before the first prompt.
    """

    def __init__(self, server_instance, timeout=600):
        self.server_instance = server_instance
        self.timeout = timeout
        self.worker_counter = 0
        self.current_worker = None
        self.interrupt_event = None
        logging.info("SubprocessWorker created - each job will run in isolated process")

    async def initialize(self):
        """Load node definitions for prompt validation. Returns node count."""
        from comfy.execution_core import init_execution_environment
        return await init_execution_environment()

    def handle_flags(self, flags):
        """Accept ``flags`` and do nothing (intentional no-op)."""
        pass

    def mark_needs_gc(self):
        """Do nothing (intentional no-op)."""
        pass

    def get_gc_timeout(self):
        """Return the fixed GC timeout value of ``1000.0``."""
        return 1000.0

    def interrupt(self, value=True):
        """Signal the current worker to stop and make sure it is gone.

        Sets the interrupt event, gives a live worker two seconds to exit
        on its own, then terminates/kills it.

        Args:
            value: A falsy value makes this call a no-op.
        """
        if not value:
            return
        if self.interrupt_event:
            self.interrupt_event.set()
        # Snapshot the attribute: execute_prompt() runs on another thread and
        # may reset self.current_worker to None between the check and the use.
        worker = self.current_worker
        if worker and worker.is_alive():
            worker.join(timeout=2)
            _kill_worker(worker, self.worker_counter)
            self.current_worker = None

    def _relay_messages(self, message_queue, server):
        """Relay queued messages to UI.

        Drains ``message_queue`` without blocking. Each message is passed
        through ``_deserialize_preview`` and, when ``server`` is truthy,
        forwarded with ``server.send_sync(event, data, sid)``.

        A message that fails to deserialize or send is logged and skipped so
        it cannot discard the messages queued behind it. A failure of the
        queue itself stops the drain.
        """
        while True:
            try:
                raw = message_queue.get_nowait()
            except queue.Empty:
                break
            except Exception:
                # The queue/pipe itself is unusable; retrying would spin.
                logging.warning("Failed to read worker message queue", exc_info=True)
                break
            try:
                msg = _deserialize_preview(raw)
                if server:
                    server.send_sync(msg['event'], msg['data'], msg['sid'])
            except Exception:
                logging.warning("Failed to relay worker message", exc_info=True)

    def execute_prompt(self, prompt, prompt_id, extra_data=None, execute_outputs=None, server=None):
        """Run one prompt in a freshly started worker process.

        Args:
            prompt: Prompt payload forwarded to the worker unchanged.
            prompt_id: Identifier of the prompt; used in logs and results.
            extra_data: Optional dict forwarded to the worker; its
                ``'client_id'`` entry (if any) is used to look up client
                metadata on ``server``. Defaults to an empty dict.
            execute_outputs: Optional list forwarded to the worker.
                Defaults to an empty list.
            server: Optional object used to relay UI messages via
                ``send_sync``; may expose ``sockets_metadata``.

        Returns:
            The object the worker put on its result queue, or an
            ``_error_result`` dict when the run was interrupted, timed out,
            or failed (including the worker dying without a result).
        """
        # Fresh containers per call instead of shared mutable defaults.
        extra_data = {} if extra_data is None else extra_data
        execute_outputs = [] if execute_outputs is None else execute_outputs

        self.worker_counter += 1
        worker_id = self.worker_counter

        job_queue = mp.Queue()
        result_queue = mp.Queue()
        message_queue = mp.Queue()
        interrupt_event = mp.Event()
        self.interrupt_event = interrupt_event

        client_id = extra_data.get('client_id')
        client_metadata = {}
        if client_id and hasattr(server, 'sockets_metadata'):
            client_metadata = server.sockets_metadata.get(client_id, {})

        job_data = {
            'prompt': prompt,
            'prompt_id': prompt_id,
            'extra_data': extra_data,
            'execute_outputs': execute_outputs,
            'client_sockets_metadata': client_metadata
        }

        from comfy.worker_process_child import worker_main
        worker = mp.Process(
            target=worker_main,
            args=(job_queue, result_queue, message_queue, interrupt_event, worker_id),
            name=f'ComfyUI-Worker-{worker_id}'
        )

        logging.info(f"Starting worker {worker_id} for prompt {prompt_id}")
        self.current_worker = worker
        worker.start()
        job_queue.put(job_data)

        try:
            # Monotonic clock: immune to wall-clock adjustments mid-run.
            start_time = time.monotonic()
            result = None

            while result is None:
                if interrupt_event.is_set():
                    logging.info(f"Worker {worker_id} interrupted")
                    if server:
                        server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
                    # interrupt() set the event and joins/kills the worker.
                    return _error_result(worker_id, prompt_id, 'Execution interrupted by user')

                if time.monotonic() - start_time > self.timeout:
                    raise TimeoutError()

                self._relay_messages(message_queue, server)

                try:
                    result = result_queue.get(timeout=0.1)
                except queue.Empty:
                    # Keep polling while the worker runs; an interrupt is
                    # reported by the check at the top of the loop.
                    if worker.is_alive() or interrupt_event.is_set():
                        continue
                    # The worker is gone. Allow a result that was flushed
                    # right before exit to arrive, otherwise fail now rather
                    # than spinning until the full timeout elapses.
                    try:
                        result = result_queue.get(timeout=1.0)
                    except queue.Empty:
                        raise RuntimeError(
                            f"worker exited without returning a result (exit code: {worker.exitcode})"
                        ) from None

            self._relay_messages(message_queue, server)

            worker.join(timeout=5)
            if worker.is_alive():
                _kill_worker(worker, worker_id)
            # Pick up any UI messages that were still in transit when the
            # result arrived.
            self._relay_messages(message_queue, server)

            logging.info(f"Worker {worker_id} cleaned up (exit code: {worker.exitcode})")
            self.current_worker = None
            return result

        except TimeoutError:
            error = f"Worker {worker_id} timed out after {self.timeout}s. Try --subprocess-timeout to increase."
            logging.error(error)
            _kill_worker(worker, worker_id)
            self.current_worker = None
            return _error_result(worker_id, prompt_id, error)

        except Exception as e:
            error = f"Worker {worker_id} IPC error: {e}"
            tb = traceback.format_exc()
            logging.error(f"{error}\n{tb}")
            _kill_worker(worker, worker_id)
            self.current_worker = None
            return _error_result(worker_id, prompt_id, error, tb)

        finally:
            # Best-effort cleanup: a failure on one queue must not prevent
            # the others from being closed, nor mask the return value.
            for q in (job_queue, result_queue, message_queue):
                try:
                    q.close()
                    q.join_thread()
                except Exception:
                    logging.debug(f"Worker {worker_id} queue cleanup failed", exc_info=True)