mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-14 00:14:31 +08:00
180 lines
5.8 KiB
Python
180 lines
5.8 KiB
Python
"""Subprocess worker for isolated prompt execution with complete GPU/ROCm reset."""
|
|
|
|
import logging
|
|
import multiprocessing as mp
|
|
import time
|
|
import traceback
|
|
|
|
mp.set_start_method('spawn', force=True)
|
|
|
|
|
|
def _deserialize_preview(msg):
|
|
"""Deserialize preview image from IPC transport."""
|
|
if not (isinstance(msg['data'], dict) and msg['data'].get('_serialized')):
|
|
return msg
|
|
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import base64
|
|
|
|
s = msg['data']
|
|
pil_image = Image.open(BytesIO(base64.b64decode(s['image_bytes'])))
|
|
msg['data'] = ((s['image_type'], pil_image, s['max_size']), s['metadata'])
|
|
return msg
|
|
|
|
|
|
def _error_result(worker_id, prompt_id, error, tb=None):
|
|
return {
|
|
'success': False,
|
|
'error': error,
|
|
'traceback': tb,
|
|
'history_result': {},
|
|
'status_messages': [],
|
|
'worker_id': worker_id,
|
|
'prompt_id': prompt_id
|
|
}
|
|
|
|
|
|
def _kill_worker(worker, worker_id):
|
|
if not worker.is_alive():
|
|
return
|
|
worker.terminate()
|
|
worker.join(timeout=2)
|
|
if worker.is_alive():
|
|
logging.warning(f"Worker {worker_id} didn't terminate, killing")
|
|
worker.kill()
|
|
worker.join()
|
|
|
|
|
|
class SubprocessWorker:
|
|
"""Executes each prompt in an isolated subprocess with fresh GPU context."""
|
|
|
|
def __init__(self, server_instance, timeout=600):
|
|
self.server_instance = server_instance
|
|
self.timeout = timeout
|
|
self.worker_counter = 0
|
|
self.current_worker = None
|
|
self.interrupt_event = None
|
|
logging.info("SubprocessWorker created - each job will run in isolated process")
|
|
|
|
async def initialize(self):
|
|
"""Load node definitions for prompt validation. Returns node count."""
|
|
from comfy.execution_core import init_execution_environment
|
|
return await init_execution_environment()
|
|
|
|
def handle_flags(self, flags):
|
|
pass
|
|
|
|
def mark_needs_gc(self):
|
|
pass
|
|
|
|
def get_gc_timeout(self):
|
|
return 1000.0
|
|
|
|
def interrupt(self, value=True):
|
|
if not value:
|
|
return
|
|
if self.interrupt_event:
|
|
self.interrupt_event.set()
|
|
if self.current_worker and self.current_worker.is_alive():
|
|
self.current_worker.join(timeout=2)
|
|
_kill_worker(self.current_worker, self.worker_counter)
|
|
self.current_worker = None
|
|
|
|
def _relay_messages(self, message_queue, server):
|
|
"""Relay queued messages to UI."""
|
|
while not message_queue.empty():
|
|
try:
|
|
msg = _deserialize_preview(message_queue.get_nowait())
|
|
if server:
|
|
server.send_sync(msg['event'], msg['data'], msg['sid'])
|
|
except:
|
|
break
|
|
|
|
def execute_prompt(self, prompt, prompt_id, extra_data={}, execute_outputs=[], server=None):
|
|
self.worker_counter += 1
|
|
worker_id = self.worker_counter
|
|
|
|
job_queue = mp.Queue()
|
|
result_queue = mp.Queue()
|
|
message_queue = mp.Queue()
|
|
self.interrupt_event = mp.Event()
|
|
|
|
client_id = extra_data.get('client_id')
|
|
client_metadata = {}
|
|
if client_id and hasattr(server, 'sockets_metadata'):
|
|
client_metadata = server.sockets_metadata.get(client_id, {})
|
|
|
|
job_data = {
|
|
'prompt': prompt,
|
|
'prompt_id': prompt_id,
|
|
'extra_data': extra_data,
|
|
'execute_outputs': execute_outputs,
|
|
'client_sockets_metadata': client_metadata
|
|
}
|
|
|
|
from comfy.worker_process_child import worker_main
|
|
worker = mp.Process(
|
|
target=worker_main,
|
|
args=(job_queue, result_queue, message_queue, self.interrupt_event, worker_id),
|
|
name=f'ComfyUI-Worker-{worker_id}'
|
|
)
|
|
|
|
logging.info(f"Starting worker {worker_id} for prompt {prompt_id}")
|
|
self.current_worker = worker
|
|
worker.start()
|
|
job_queue.put(job_data)
|
|
|
|
try:
|
|
start_time = time.time()
|
|
result = None
|
|
|
|
while result is None:
|
|
if self.interrupt_event.is_set():
|
|
logging.info(f"Worker {worker_id} interrupted")
|
|
if server:
|
|
server.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server.client_id)
|
|
return _error_result(worker_id, prompt_id, 'Execution interrupted by user')
|
|
|
|
if time.time() - start_time > self.timeout:
|
|
raise TimeoutError()
|
|
|
|
self._relay_messages(message_queue, server)
|
|
|
|
try:
|
|
result = result_queue.get(timeout=0.1)
|
|
except mp.queues.Empty:
|
|
pass
|
|
|
|
self._relay_messages(message_queue, server)
|
|
|
|
worker.join(timeout=5)
|
|
if worker.is_alive():
|
|
_kill_worker(worker, worker_id)
|
|
|
|
logging.info(f"Worker {worker_id} cleaned up (exit code: {worker.exitcode})")
|
|
self.current_worker = None
|
|
return result
|
|
|
|
except TimeoutError:
|
|
error = f"Worker {worker_id} timed out after {self.timeout}s. Try --subprocess-timeout to increase."
|
|
logging.error(error)
|
|
_kill_worker(worker, worker_id)
|
|
self.current_worker = None
|
|
return _error_result(worker_id, prompt_id, error)
|
|
|
|
except Exception as e:
|
|
error = f"Worker {worker_id} IPC error: {e}"
|
|
logging.error(f"{error}\n{traceback.format_exc()}")
|
|
_kill_worker(worker, worker_id)
|
|
self.current_worker = None
|
|
return _error_result(worker_id, prompt_id, error, traceback.format_exc())
|
|
|
|
finally:
|
|
for q in (job_queue, result_queue, message_queue):
|
|
q.close()
|
|
try:
|
|
q.join_thread()
|
|
except:
|
|
pass
|