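"""Filesystem asset scanner (app/assets_scanner.py).

Runs a fast reconcile pass over the supported roots ("models", "input",
"output"), counting files already known to the DB and queueing new or changed
files for slow hashing via assets_manager.populate_db_with_asset().

schedule_scans() starts on-demand scans, fast_reconcile_and_kickoff() is the
startup variant that finishes the fast pass before returning, and
current_statuses() reports per-root progress.
"""
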
import asyncio
import logging
import os
import uuid
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Literal, Optional, Sequence
import folder_paths
from . import assets_manager
from .api import schemas_out
from ._assets_helpers import get_comfy_models_folders
from .database.db import create_session
from .database.services import check_fs_asset_exists_quick

LOGGER = logging.getLogger(__name__)

RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")

SLOW_HASH_CONCURRENCY = 1  # number of concurrent slow-hash worker tasks per root


@dataclass
class ScanProgress:
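    """Progress bookkeeping for a single root scan; converted to AssetScanStatus for the API."""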
scan_id: str
root: RootType
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
scheduled_at: float = field(default_factory=lambda: time.time())
started_at: Optional[float] = None
finished_at: Optional[float] = None
discovered: int = 0
processed: int = 0
slow_queue_total: int = 0
slow_queue_finished: int = 0
file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"}
# Internal diagnostics for logs
_fast_total_seen: int = 0
_fast_clean: int = 0


@dataclass
class SlowQueueState:
    """Per-root slow-hash queue plus the worker tasks that drain it."""
    queue: asyncio.Queue
    workers: list[asyncio.Task] = field(default_factory=list)
    closed: bool = False
    # Reference to the sentinel-pushing task so it is not garbage-collected mid-run.
    closer: Optional[asyncio.Task] = None
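

# Per-root registries: the running scan task, the latest progress snapshot,
# and the slow-hash queue state.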
RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {}


def _new_scan_id(root: RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"


def current_statuses() -> schemas_out.AssetScanStatusResponse:
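    """Return status snapshots for every root that has a recorded scan."""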
scans = []
for root in ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
scans.append(_scan_progress_to_scan_status_model(prog))
return schemas_out.AssetScanStatusResponse(scans=scans)


async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
    """Schedule scans for the provided roots; returns progress snapshots.

    Rules:
- Only roots in {models, input, output} are accepted.
- If a root is already scanning, we do NOT enqueue another one. Status returned as-is.
- Otherwise a new task is created and started immediately.
- Files with zero size are skipped.
"""
normalized: list[RootType] = []
seen = set()
for r in roots or []:
rr = r.strip().lower()
if rr in ALLOWED_ROOTS and rr not in seen:
normalized.append(rr) # type: ignore
seen.add(rr)
if not normalized:
normalized = list(ALLOWED_ROOTS) # schedule all by default
results: list[ScanProgress] = []
for root in normalized:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
results.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue())
RUNNING_TASKS[root] = asyncio.create_task(
_pipeline_for_root(root, prog, progress_cb=None),
name=f"asset-scan:{root}",
)
results.append(prog)
return _status_response_for(results)


async def fast_reconcile_and_kickoff(
    roots: Sequence[str] | None = None,
    *,
    progress_cb: Optional[Callable[[dict], None]] = None,
) -> schemas_out.AssetScanStatusResponse:
    """
    Startup helper: run the fast pass now (so the slow-queue size is known),
    start slow hashing in the background, and return immediately.
    """
normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS]
snaps: list[ScanProgress] = []
for root in normalized:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
snaps.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
prog.status = "running"
prog.started_at = time.time()
try:
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
except Exception as e:
_append_error(prog, phase="fast", path="", message=str(e))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Fast reconcile failed for %s", root)
snaps.append(prog)
continue
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
RUNNING_TASKS[root] = asyncio.create_task(
_await_workers_then_finish(root, prog, state, progress_cb=progress_cb),
name=f"asset-hash:{root}",
)
snaps.append(prog)
return _status_response_for(snaps)


def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])


def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
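    """Convert internal ScanProgress state into the API AssetScanStatus model."""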
return schemas_out.AssetScanStatus(
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=_ts_to_iso(progress.scheduled_at),
started_at=_ts_to_iso(progress.started_at),
finished_at=_ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
slow_queue_total=progress.slow_queue_total,
slow_queue_finished=progress.slow_queue_finished,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
phase=e.get("phase", "slow"),
at=e.get("at"),
)
for e in (progress.file_errors or [])
],
)


async def _pipeline_for_root(
root: RootType,
prog: ScanProgress,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
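    """Run the full scan for one root: fast reconcile, then slow hashing until the queue drains."""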
state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
prog.status = "running"
prog.started_at = time.time()
try:
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
_append_error(prog, phase="slow", path="", message=str(exc))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
finally:
RUNNING_TASKS.pop(root, None)


async def _fast_reconcile_into_queue(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
"""
Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files,
and queue the rest for slow hashing.
"""
if root == "models":
files = _collect_models_files()
preset_discovered = len(files)
files_iter = asyncio.Queue()
for p in files:
await files_iter.put(p)
await files_iter.put(None) # sentinel for our local draining loop
elif root == "input":
base = folder_paths.get_input_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base))
files_iter = await _queue_tree_files(base)
elif root == "output":
base = folder_paths.get_output_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base))
files_iter = await _queue_tree_files(base)
else:
raise RuntimeError(f"Unsupported root: {root}")
prog.discovered = int(preset_discovered or 0)
queued = 0
checked = 0
clean = 0
# Single session for the whole fast pass
async with await create_session() as sess:
while True:
item = await files_iter.get()
files_iter.task_done()
if item is None:
break
abs_path = item
checked += 1
# Stat; skip empty/unreadable
try:
st = os.stat(abs_path, follow_symlinks=True)
if not st.st_size:
continue
size_bytes = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
except OSError as e:
_append_error(prog, phase="fast", path=abs_path, message=str(e))
continue
# Known good -> count as processed immediately
try:
known = await check_fs_asset_exists_quick(
sess,
file_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
)
except Exception as e:
_append_error(prog, phase="fast", path=abs_path, message=str(e))
known = False
if known:
clean += 1
                prog.processed += 1  # already in the DB and unchanged on disk; no slow hashing needed
else:
await state.queue.put(abs_path)
queued += 1
prog.slow_queue_total += 1
if progress_cb:
progress_cb({
"root": root,
"phase": "fast",
"checked": checked,
"clean": clean,
"queued": queued,
"discovered": prog.discovered,
"processed": prog.processed,
})
prog._fast_total_seen = checked
prog._fast_clean = clean
if progress_cb:
progress_cb({
"root": root,
"phase": "fast",
"checked": checked,
"clean": clean,
"queued": queued,
"discovered": prog.discovered,
"processed": prog.processed,
"done": True,
})
state.closed = True


def _start_slow_workers(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
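    """Spawn the slow-hash worker tasks for this root; a no-op if workers already exist."""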
if state.workers:
return

    async def _worker(_worker_id: int):
while True:
item = await state.queue.get()
try:
if item is None:
return
try:
await asyncio.to_thread(assets_manager.populate_db_with_asset, item)
except Exception as e:
_append_error(prog, phase="slow", path=item, message=str(e))
finally:
# Slow queue finished for this item; also counts toward overall processed
prog.slow_queue_finished += 1
prog.processed += 1
if progress_cb:
progress_cb({
"root": root,
"phase": "slow",
"processed": prog.processed,
"slow_queue_finished": prog.slow_queue_finished,
"slow_queue_total": prog.slow_queue_total,
})
finally:
state.queue.task_done()
    state.workers = [
        asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
        for i in range(SLOW_HASH_CONCURRENCY)
    ]

    async def _close_when_empty():
# When the fast phase closed the queue, push sentinels to end workers
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
    # Keep a reference so the closer task is not garbage-collected before it runs.
    state.closer = asyncio.create_task(_close_when_empty())


async def _await_workers_then_finish(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
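    """Wait for the slow workers to drain the queue, then mark the scan completed."""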
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
prog.finished_at = time.time()
prog.status = "completed"
if progress_cb:
progress_cb({
"root": root,
"phase": "slow",
"processed": prog.processed,
"slow_queue_finished": prog.slow_queue_finished,
"slow_queue_total": prog.slow_queue_total,
"done": True,
})


def _collect_models_files() -> list[str]:
"""Collect absolute file paths from configured model buckets under models_dir."""
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
# ensure within allowed bases
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
try:
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
                except ValueError:
                    # commonpath() raises ValueError for paths on different drives or a
                    # mix of absolute/relative paths; treat that as "not under this base".
                    pass
if allowed:
out.append(abs_path)
return out


def _count_files_in_tree(base_abs: str) -> int:
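    """Count files under base_abs; symlinked directories are not followed."""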
if not os.path.isdir(base_abs):
return 0
total = 0
for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
total += len(filenames)
return total


async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
"""
Walk base_dir in a worker thread and return a queue prefilled with all paths,
terminated by a single None sentinel for the draining loop in fast reconcile.
"""
q: asyncio.Queue = asyncio.Queue()
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
await q.put(None)
return q
def _walk_list():
paths: list[str] = []
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
paths.append(os.path.abspath(os.path.join(dirpath, name)))
return paths
for p in await asyncio.to_thread(_walk_list):
await q.put(p)
await q.put(None)
return q


def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None:
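    """Append a timestamped error entry to the progress object; path is empty for phase-level failures."""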
prog.file_errors.append({
"path": path,
"message": message,
"phase": phase,
"at": _ts_to_iso(time.time()),
})


def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
# interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models)
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None