From ce270ba090149baacdcb958dc7ba65a37d87e648 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Fri, 5 Sep 2025 17:46:09 +0300 Subject: [PATCH] added Assets Autoscan feature --- app/api/schemas_out.py | 16 +- app/assets_manager.py | 2 +- app/assets_scanner.py | 504 ++++++++++++++++++++++++++--------------- comfy/cli_args.py | 3 +- main.py | 25 ++ 5 files changed, 364 insertions(+), 186 deletions(-) diff --git a/app/api/schemas_out.py b/app/api/schemas_out.py index 581d71796..8bb34096b 100644 --- a/app/api/schemas_out.py +++ b/app/api/schemas_out.py @@ -92,17 +92,25 @@ class TagsRemove(BaseModel): total_tags: list[str] = Field(default_factory=list) +class AssetScanError(BaseModel): + path: str + message: str + phase: Literal["fast", "slow"] + at: Optional[str] = Field(None, description="ISO timestamp") + + class AssetScanStatus(BaseModel): scan_id: str - root: Literal["models","input","output"] - status: Literal["scheduled","running","completed","failed","cancelled"] + root: Literal["models", "input", "output"] + status: Literal["scheduled", "running", "completed", "failed", "cancelled"] scheduled_at: Optional[str] = None started_at: Optional[str] = None finished_at: Optional[str] = None discovered: int = 0 processed: int = 0 - errors: int = 0 - last_error: Optional[str] = None + slow_queue_total: int = 0 + slow_queue_finished: int = 0 + file_errors: list[AssetScanError] = Field(default_factory=list) class AssetScanStatusResponse(BaseModel): diff --git a/app/assets_manager.py b/app/assets_manager.py index 3d7c040c4..b84b61508 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -43,7 +43,7 @@ async def asset_exists(*, asset_hash: str) -> bool: def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None: - if not args.disable_model_processing: + if not args.enable_model_processing: if tags is None: tags = [] try: diff --git a/app/assets_scanner.py b/app/assets_scanner.py index 5bafd6bb7..ccfc8e9e5 100644 --- a/app/assets_scanner.py +++ b/app/assets_scanner.py @@ -5,22 +5,22 @@ import uuid import time from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Literal, Optional, Sequence +from typing import Callable, Literal, Optional, Sequence import folder_paths from . import assets_manager from .api import schemas_out from ._assets_helpers import get_comfy_models_folders +from .database.db import create_session +from .database.services import check_fs_asset_exists_quick LOGGER = logging.getLogger(__name__) RootType = Literal["models", "input", "output"] ALLOWED_ROOTS: tuple[RootType, ...] 
= ("models", "input", "output") -# We run at most one scan per root; overall max parallelism is therefore 3 -# We also bound per-scan ingestion concurrency to avoid swamping threads/DB -DEFAULT_PER_SCAN_CONCURRENCY = 1 +SLOW_HASH_CONCURRENCY = 1 @dataclass @@ -34,15 +34,25 @@ class ScanProgress: discovered: int = 0 processed: int = 0 - errors: int = 0 - last_error: Optional[str] = None + slow_queue_total: int = 0 + slow_queue_finished: int = 0 + file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"} - # Optional details for diagnostics (e.g., files per bucket) - details: dict[str, int] = field(default_factory=dict) + # Internal diagnostics for logs + _fast_total_seen: int = 0 + _fast_clean: int = 0 + + +@dataclass +class SlowQueueState: + queue: asyncio.Queue + workers: list[asyncio.Task] = field(default_factory=list) + closed: bool = False RUNNING_TASKS: dict[RootType, asyncio.Task] = {} PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {} +SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {} def _new_scan_id(root: RootType) -> str: @@ -50,23 +60,13 @@ def _new_scan_id(root: RootType) -> str: def current_statuses() -> schemas_out.AssetScanStatusResponse: - return schemas_out.AssetScanStatusResponse( - scans=[ - schemas_out.AssetScanStatus( - scan_id=s.scan_id, - root=s.root, - status=s.status, - scheduled_at=_ts_to_iso(s.scheduled_at), - started_at=_ts_to_iso(s.started_at), - finished_at=_ts_to_iso(s.finished_at), - discovered=s.discovered, - processed=s.processed, - errors=s.errors, - last_error=s.last_error, - ) - for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT] - ] - ) + scans = [] + for root in ALLOWED_ROOTS: + prog = PROGRESS_BY_ROOT.get(root) + if not prog: + continue + scans.append(_scan_progress_to_scan_status_model(prog)) + return schemas_out.AssetScanStatusResponse(scans=scans) async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse: @@ -81,8 +81,6 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes normalized: list[RootType] = [] seen = set() for r in roots or []: - if not isinstance(r, str): - continue rr = r.strip().lower() if rr in ALLOWED_ROOTS and rr not in seen: normalized.append(rr) # type: ignore @@ -98,142 +96,311 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") PROGRESS_BY_ROOT[root] = prog - - task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}") - RUNNING_TASKS[root] = task + SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue()) + RUNNING_TASKS[root] = asyncio.create_task( + _pipeline_for_root(root, prog, progress_cb=None), + name=f"asset-scan:{root}", + ) results.append(prog) + return _status_response_for(results) - return schemas_out.AssetScanStatusResponse( - scans=[ - schemas_out.AssetScanStatus( - scan_id=s.scan_id, - root=s.root, - status=s.status, - scheduled_at=_ts_to_iso(s.scheduled_at), - started_at=_ts_to_iso(s.started_at), - finished_at=_ts_to_iso(s.finished_at), - discovered=s.discovered, - processed=s.processed, - errors=s.errors, - last_error=s.last_error, + +async def fast_reconcile_and_kickoff( + roots: Sequence[str] | None = None, + *, + progress_cb: Optional[Callable[[dict], None]] = None, +) -> schemas_out.AssetScanStatusResponse: + """ + Startup helper: do the fast pass now (so we know queue size), + start slow hashing in the background, return immediately. 
+ """ + normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS] + snaps: list[ScanProgress] = [] + + for root in normalized: + if root in RUNNING_TASKS and not RUNNING_TASKS[root].done(): + snaps.append(PROGRESS_BY_ROOT[root]) + continue + + prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled") + PROGRESS_BY_ROOT[root] = prog + state = SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state + + prog.status = "running" + prog.started_at = time.time() + try: + await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) + except Exception as e: + _append_error(prog, phase="fast", path="", message=str(e)) + prog.status = "failed" + prog.finished_at = time.time() + LOGGER.exception("Fast reconcile failed for %s", root) + snaps.append(prog) + continue + + _start_slow_workers(root, prog, state, progress_cb=progress_cb) + RUNNING_TASKS[root] = asyncio.create_task( + _await_workers_then_finish(root, prog, state, progress_cb=progress_cb), + name=f"asset-hash:{root}", + ) + snaps.append(prog) + return _status_response_for(snaps) + + +def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse: + return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses]) + + +def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus: + return schemas_out.AssetScanStatus( + scan_id=progress.scan_id, + root=progress.root, + status=progress.status, + scheduled_at=_ts_to_iso(progress.scheduled_at), + started_at=_ts_to_iso(progress.started_at), + finished_at=_ts_to_iso(progress.finished_at), + discovered=progress.discovered, + processed=progress.processed, + slow_queue_total=progress.slow_queue_total, + slow_queue_finished=progress.slow_queue_finished, + file_errors=[ + schemas_out.AssetScanError( + path=e.get("path", ""), + message=e.get("message", ""), + phase=e.get("phase", "slow"), + at=e.get("at"), ) - for s in results - ] + for e in (progress.file_errors or []) + ], ) -async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None: - prog.started_at = time.time() +async def _pipeline_for_root( + root: RootType, + prog: ScanProgress, + progress_cb: Optional[Callable[[dict], None]], +) -> None: + state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue()) + SLOW_STATE_BY_ROOT[root] = state + prog.status = "running" + prog.started_at = time.time() + try: - if root == "models": - await _scan_models(prog) - elif root == "input": - base = folder_paths.get_input_directory() - await _scan_directory_tree(base, root, prog) - elif root == "output": - base = folder_paths.get_output_directory() - await _scan_directory_tree(base, root, prog) - else: - raise RuntimeError(f"Unsupported root: {root}") - prog.status = "completed" + await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb) + _start_slow_workers(root, prog, state, progress_cb=progress_cb) + await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb) except asyncio.CancelledError: prog.status = "cancelled" raise except Exception as exc: - LOGGER.exception("Asset scan failed for %s", root) + _append_error(prog, phase="slow", path="", message=str(exc)) prog.status = "failed" - prog.errors += 1 - prog.last_error = str(exc) - finally: prog.finished_at = time.time() + LOGGER.exception("Asset scan failed for %s", root) + finally: RUNNING_TASKS.pop(root, None) -async def _scan_models(prog: 
ScanProgress) -> None: +async def _fast_reconcile_into_queue( + root: RootType, + prog: ScanProgress, + state: SlowQueueState, + *, + progress_cb: Optional[Callable[[dict], None]], +) -> None: """ - Scan all configured model buckets from folder_paths.folder_names_and_paths, - restricted to entries whose base paths lie under folder_paths.models_dir - (per get_comfy_models_folders). We trust those mappings and do not try to - infer anything else here. + Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files, + and queue the rest for slow hashing. """ - targets: list[tuple[str, list[str]]] = get_comfy_models_folders() + if root == "models": + files = _collect_models_files() + preset_discovered = len(files) + files_iter = asyncio.Queue() + for p in files: + await files_iter.put(p) + await files_iter.put(None) # sentinel for our local draining loop + elif root == "input": + base = folder_paths.get_input_directory() + preset_discovered = _count_files_in_tree(os.path.abspath(base)) + files_iter = await _queue_tree_files(base) + elif root == "output": + base = folder_paths.get_output_directory() + preset_discovered = _count_files_in_tree(os.path.abspath(base)) + files_iter = await _queue_tree_files(base) + else: + raise RuntimeError(f"Unsupported root: {root}") - plans: list[str] = [] # absolute file paths to ingest - per_bucket: dict[str, int] = {} + prog.discovered = int(preset_discovered or 0) - for folder_name, bases in targets: + queued = 0 + checked = 0 + clean = 0 + + # Single session for the whole fast pass + async with await create_session() as sess: + while True: + item = await files_iter.get() + files_iter.task_done() + if item is None: + break + + abs_path = item + checked += 1 + + # Stat; skip empty/unreadable + try: + st = os.stat(abs_path, follow_symlinks=True) + if not st.st_size: + continue + size_bytes = int(st.st_size) + mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + except OSError as e: + _append_error(prog, phase="fast", path=abs_path, message=str(e)) + continue + + # Known good -> count as processed immediately + try: + known = await check_fs_asset_exists_quick( + sess, + file_path=abs_path, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + ) + except Exception as e: + _append_error(prog, phase="fast", path=abs_path, message=str(e)) + known = False + + if known: + clean += 1 + prog.processed += 1 # preserve original semantics + else: + await state.queue.put(abs_path) + queued += 1 + prog.slow_queue_total += 1 + + if progress_cb: + progress_cb({ + "root": root, + "phase": "fast", + "checked": checked, + "clean": clean, + "queued": queued, + "discovered": prog.discovered, + "processed": prog.processed, + }) + + prog._fast_total_seen = checked + prog._fast_clean = clean + + if progress_cb: + progress_cb({ + "root": root, + "phase": "fast", + "checked": checked, + "clean": clean, + "queued": queued, + "discovered": prog.discovered, + "processed": prog.processed, + "done": True, + }) + + state.closed = True + + +def _start_slow_workers( + root: RootType, + prog: ScanProgress, + state: SlowQueueState, + *, + progress_cb: Optional[Callable[[dict], None]], +) -> None: + if state.workers: + return + + async def _worker(_worker_id: int): + while True: + item = await state.queue.get() + try: + if item is None: + return + try: + await asyncio.to_thread(assets_manager.populate_db_with_asset, item) + except Exception as e: + _append_error(prog, phase="slow", path=item, message=str(e)) + finally: + # Slow queue finished 
for this item; also counts toward overall processed + prog.slow_queue_finished += 1 + prog.processed += 1 + if progress_cb: + progress_cb({ + "root": root, + "phase": "slow", + "processed": prog.processed, + "slow_queue_finished": prog.slow_queue_finished, + "slow_queue_total": prog.slow_queue_total, + }) + finally: + state.queue.task_done() + + state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)] + + async def _close_when_empty(): + # When the fast phase closed the queue, push sentinels to end workers + while not state.closed: + await asyncio.sleep(0.05) + for _ in range(SLOW_HASH_CONCURRENCY): + await state.queue.put(None) + + asyncio.create_task(_close_when_empty()) + + +async def _await_workers_then_finish( + root: RootType, + prog: ScanProgress, + state: SlowQueueState, + *, + progress_cb: Optional[Callable[[dict], None]], +) -> None: + if state.workers: + await asyncio.gather(*state.workers, return_exceptions=True) + prog.finished_at = time.time() + prog.status = "completed" + if progress_cb: + progress_cb({ + "root": root, + "phase": "slow", + "processed": prog.processed, + "slow_queue_finished": prog.slow_queue_finished, + "slow_queue_total": prog.slow_queue_total, + "done": True, + }) + + +def _collect_models_files() -> list[str]: + """Collect absolute file paths from configured model buckets under models_dir.""" + out: list[str] = [] + for folder_name, bases in get_comfy_models_folders(): rel_files = folder_paths.get_filename_list(folder_name) or [] - count_valid = 0 - for rel_path in rel_files: abs_path = folder_paths.get_full_path(folder_name, rel_path) if not abs_path: continue abs_path = os.path.abspath(abs_path) - - # Extra safety: ensure file is inside one of the allowed base paths + # ensure within allowed bases allowed = False - for base in bases: - base_abs = os.path.abspath(base) + for b in bases: + base_abs = os.path.abspath(b) try: - common = os.path.commonpath([abs_path, base_abs]) - except ValueError: - common = "" # Different drives on Windows - if common == base_abs: - allowed = True - break - if not allowed: - LOGGER.warning("Skipping file outside models base: %s", abs_path) - continue - - try: - if not os.path.getsize(abs_path): - continue # skip empty files - except OSError as e: - LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) - continue - - plans.append(abs_path) - count_valid += 1 - - if count_valid: - per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid - - prog.discovered = len(plans) - for k, v in per_bucket.items(): - prog.details[k] = prog.details.get(k, 0) + v - - if not plans: - LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id) - return - - sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) - tasks: list[asyncio.Task] = [] - - for abs_path in plans: - async def worker(fp_abs: str = abs_path): - try: - # Offload sync ingestion into a thread; populate_db_with_asset - # derives name and tags from the path using _assets_helpers. 
- await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs) - except Exception as e: - prog.errors += 1 - prog.last_error = str(e) - LOGGER.debug("Error ingesting %s: %s", fp_abs, e) - finally: - prog.processed += 1 - sem.release() - - await sem.acquire() - tasks.append(asyncio.create_task(worker())) - - if tasks: - await asyncio.gather(*tasks) - LOGGER.info( - "Model scan %s finished: discovered=%d processed=%d errors=%d", - prog.scan_id, prog.discovered, prog.processed, prog.errors - ) + if os.path.commonpath([abs_path, base_abs]) == base_abs: + allowed = True + break + except Exception: + pass + if allowed: + out.append(abs_path) + return out def _count_files_in_tree(base_abs: str) -> int: @@ -245,60 +412,37 @@ def _count_files_in_tree(base_abs: str) -> int: return total -async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None: +async def _queue_tree_files(base_dir: str) -> asyncio.Queue: """ - Generic scanner for input/output roots. We pass only the absolute path to - populate_db_with_asset and let it derive the relative name and tags. + Walk base_dir in a worker thread and return a queue prefilled with all paths, + terminated by a single None sentinel for the draining loop in fast reconcile. """ + q: asyncio.Queue = asyncio.Queue() base_abs = os.path.abspath(base_dir) if not os.path.isdir(base_abs): - LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs) - return + await q.put(None) + return q - prog.discovered = _count_files_in_tree(base_abs) + def _walk_list(): + paths: list[str] = [] + for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for name in filenames: + paths.append(os.path.abspath(os.path.join(dirpath, name))) + return paths - sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY) - tasks: list[asyncio.Task] = [] - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): - for name in filenames: - abs_path = os.path.abspath(os.path.join(dirpath, name)) + for p in await asyncio.to_thread(_walk_list): + await q.put(p) + await q.put(None) + return q - # Safety: ensure within base - try: - if os.path.commonpath([abs_path, base_abs]) != base_abs: - LOGGER.warning("Skipping path outside root %s: %s", root, abs_path) - continue - except ValueError: - continue - # Skip empty files and handle stat errors - try: - if not os.path.getsize(abs_path): - continue - except OSError as e: - LOGGER.warning("Could not stat %s: %s – skipping", abs_path, e) - continue - - async def worker(fp_abs: str = abs_path): - try: - await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs) - except Exception as e: - prog.errors += 1 - prog.last_error = str(e) - finally: - prog.processed += 1 - sem.release() - - await sem.acquire() - tasks.append(asyncio.create_task(worker())) - - if tasks: - await asyncio.gather(*tasks) - - LOGGER.info( - "%s scan %s finished: discovered=%d processed=%d errors=%d", - root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors - ) +def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None: + prog.file_errors.append({ + "path": path, + "message": message, + "phase": phase, + "at": _ts_to_iso(time.time()), + }) def _ts_to_iso(ts: Optional[float]) -> Optional[str]: diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 7de4adbdc..5e301b505 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -212,7 +212,8 @@ database_default_path = os.path.abspath( 
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") ) parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.") -parser.add_argument("--disable-model-processing", action="store_true", help="Disable automatic processing of the model file, such as calculating hashes and populating the database.") +parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.") +parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") if comfy.options.args_parsing: args = parser.parse_args() diff --git a/main.py b/main.py index 557961d40..017f88a63 100644 --- a/main.py +++ b/main.py @@ -279,10 +279,35 @@ def cleanup_temp(): shutil.rmtree(temp_dir, ignore_errors=True) async def setup_database(): + def _console_cb(e: dict): + root = e.get("root") + phase = e.get("phase") + if phase == "fast": + if e.get("done"): + logging.info( + f"[assets][{root}] fast done: processed={e['processed']}/{e['discovered']} queued={e['queued']}" + ) + elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress + logging.info(f"[assets][{root}] fast progress: processed={e['processed']}/{e['discovered']}" + ) + elif phase == "slow": + if e.get("done"): + logging.info( + f"[assets][{root}] slow done: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}" + ) + else: + logging.info( + f"[assets][{root}] slow progress: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}" + ) + try: from app.database.db import init_db_engine, dependencies_available if dependencies_available(): await init_db_engine() + if not args.disable_assets_autoscan: + from app import assets_scanner + + await assets_scanner.fast_reconcile_and_kickoff(progress_cb=_console_cb) except Exception as e: logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
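
Reviewer note (not part of the patch): a minimal usage sketch of the new startup hook. It assumes the app package layout introduced above (app.assets_scanner), ComfyUI's runtime environment, and that the database engine was already initialized via app.database.db.init_db_engine(), which is what setup_database() in main.py does before calling the scanner. The callback keys ("root", "phase", "done", "processed", "slow_queue_finished", "slow_queue_total") are the ones emitted by _fast_reconcile_into_queue and the slow workers.

# Sketch: driving the autoscan from custom startup code and polling its progress.
import logging

from app import assets_scanner


def _progress_cb(event: dict) -> None:
    # Events carry "root", "phase" ("fast" or "slow") and counters; a "done"
    # key marks the end of a phase for that root.
    if event.get("done"):
        logging.info(
            "[assets][%s] %s phase finished: processed=%s",
            event.get("root"), event.get("phase"), event.get("processed"),
        )


async def run_autoscan() -> None:
    # The fast pass completes before this call returns; slow hashing keeps
    # running in background worker tasks on the same event loop.
    statuses = await assets_scanner.fast_reconcile_and_kickoff(
        ["models", "output"], progress_cb=_progress_cb
    )
    for scan in statuses.scans:
        logging.info("scan %s on %s: %s", scan.scan_id, scan.root, scan.status)

    # Poll later to watch slow_queue_finished approach slow_queue_total.
    snapshot = assets_scanner.current_statuses()
    for scan in snapshot.scans:
        logging.info(
            "[assets][%s] hashed %d/%d",
            scan.root, scan.slow_queue_finished, scan.slow_queue_total,
        )

The same counters surface in AssetScanStatusResponse (what current_statuses() returns), so API consumers see slow_queue_finished advance toward slow_queue_total as the background workers drain the queue.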