# ComfyUI/app/assets_scanner.py

import asyncio
import logging
import os
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional, Sequence

import folder_paths

from . import assets_manager
from ._assets_helpers import get_comfy_models_folders
from .api import schemas_out

LOGGER = logging.getLogger(__name__)

RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")

# We run at most one scan per root; overall max parallelism is therefore 3.
# We also bound per-scan ingestion concurrency to avoid swamping threads/DB.
DEFAULT_PER_SCAN_CONCURRENCY = 1
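# Note: with DEFAULT_PER_SCAN_CONCURRENCY = 1, files within one scan are
# ingested strictly one at a time; raising it lets several threaded
# populate_db_with_asset calls overlap, at the cost of more pressure on the
# threads and DB mentioned above.

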
@dataclass
class ScanProgress:
    scan_id: str
    root: RootType
    status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
    scheduled_at: float = field(default_factory=time.time)
    started_at: Optional[float] = None
    finished_at: Optional[float] = None
    discovered: int = 0
    processed: int = 0
    errors: int = 0
    last_error: Optional[str] = None
    # Optional details for diagnostics (e.g., files per bucket).
    details: dict[str, int] = field(default_factory=dict)


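# Module-level registries: at most one live scan task per root, plus the most
# recent progress snapshot per root (kept after completion so status queries
# can still report it).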
RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}


def _new_scan_id(root: RootType) -> str:
    return f"scan-{root}-{uuid.uuid4().hex[:8]}"


def current_statuses() -> schemas_out.AssetScanStatusResponse:
    return schemas_out.AssetScanStatusResponse(
        scans=[
            schemas_out.AssetScanStatus(
                scan_id=s.scan_id,
                root=s.root,
                status=s.status,
                scheduled_at=_ts_to_iso(s.scheduled_at),
                started_at=_ts_to_iso(s.started_at),
                finished_at=_ts_to_iso(s.finished_at),
                discovered=s.discovered,
                processed=s.processed,
                errors=s.errors,
                last_error=s.last_error,
            )
            for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
        ]
    )


async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
    """Schedule scans for the provided roots; returns progress snapshots.

    Rules:
    - Only roots in {models, input, output} are accepted.
    - If a root is already scanning, we do NOT enqueue another one; its current
      status is returned as-is.
    - Otherwise a new task is created and started immediately.
    - Files with zero size are skipped.
    """
    normalized: list[RootType] = []
    seen = set()
    for r in roots or []:
        if not isinstance(r, str):
            continue
        rr = r.strip().lower()
        if rr in ALLOWED_ROOTS and rr not in seen:
            normalized.append(rr)  # type: ignore
            seen.add(rr)
    if not normalized:
        normalized = list(ALLOWED_ROOTS)  # schedule all by default
    results: list[ScanProgress] = []
    for root in normalized:
        if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
            results.append(PROGRESS_BY_ROOT[root])
            continue
        prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
        PROGRESS_BY_ROOT[root] = prog
        task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
        RUNNING_TASKS[root] = task
        results.append(prog)
    return schemas_out.AssetScanStatusResponse(
        scans=[
            schemas_out.AssetScanStatus(
                scan_id=s.scan_id,
                root=s.root,
                status=s.status,
                scheduled_at=_ts_to_iso(s.scheduled_at),
                started_at=_ts_to_iso(s.started_at),
                finished_at=_ts_to_iso(s.finished_at),
                discovered=s.discovered,
                processed=s.processed,
                errors=s.errors,
                last_error=s.last_error,
            )
            for s in results
        ]
    )
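
# Minimal usage sketch (illustrative only, not part of this module's API):
# kick off a models-only scan from async context and poll the in-memory
# registries until the task finishes. Assumes the caller runs on the same
# event loop as the scan tasks.
#
#     async def rescan_models() -> schemas_out.AssetScanStatusResponse:
#         await schedule_scans(["models"])
#         while "models" in RUNNING_TASKS and not RUNNING_TASKS["models"].done():
#             await asyncio.sleep(0.5)
#         return current_statuses()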


async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
    prog.started_at = time.time()
    prog.status = "running"
    try:
        if root == "models":
            await _scan_models(prog)
        elif root == "input":
            base = folder_paths.get_input_directory()
            await _scan_directory_tree(base, root, prog)
        elif root == "output":
            base = folder_paths.get_output_directory()
            await _scan_directory_tree(base, root, prog)
        else:
            raise RuntimeError(f"Unsupported root: {root}")
        prog.status = "completed"
    except asyncio.CancelledError:
        prog.status = "cancelled"
        raise
    except Exception as exc:
        LOGGER.exception("Asset scan failed for %s", root)
        prog.status = "failed"
        prog.errors += 1
        prog.last_error = str(exc)
    finally:
        prog.finished_at = time.time()
        # Task.done() is still False inside the task's own finally block, so a
        # done() check here would never fire; compare identities instead to
        # drop our own registry entry.
        if RUNNING_TASKS.get(root) is asyncio.current_task():
            RUNNING_TASKS.pop(root, None)
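
# Cancellation note: CancelledError is re-raised after the status bookkeeping,
# so awaiting callers still observe the cancellation while the progress entry
# records "cancelled" together with a finish timestamp.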


async def _scan_models(prog: ScanProgress) -> None:
    """
    Scan all configured model buckets from folder_paths.folder_names_and_paths,
    restricted to entries whose base paths lie under folder_paths.models_dir
    (per get_comfy_models_folders). We trust those mappings and do not try to
    infer anything else here.
    """
    targets: list[tuple[str, list[str]]] = get_comfy_models_folders()
    plans: list[str] = []  # absolute file paths to ingest
    per_bucket: dict[str, int] = {}
    for folder_name, bases in targets:
        rel_files = folder_paths.get_filename_list(folder_name) or []
        count_valid = 0
        for rel_path in rel_files:
            abs_path = folder_paths.get_full_path(folder_name, rel_path)
            if not abs_path:
                continue
            abs_path = os.path.abspath(abs_path)
            # Extra safety: ensure the file is inside one of the allowed base paths.
            allowed = False
            for base in bases:
                base_abs = os.path.abspath(base)
                try:
                    common = os.path.commonpath([abs_path, base_abs])
                except ValueError:
                    common = ""  # different drives on Windows
                if common == base_abs:
                    allowed = True
                    break
            if not allowed:
                LOGGER.warning("Skipping file outside models base: %s", abs_path)
                continue
            try:
                if not os.path.getsize(abs_path):
                    continue  # skip empty files
            except OSError as e:
                LOGGER.warning("Could not stat %s (%s); skipping", abs_path, e)
                continue
            plans.append(abs_path)
            count_valid += 1
        if count_valid:
            per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid
    prog.discovered = len(plans)
    for k, v in per_bucket.items():
        prog.details[k] = prog.details.get(k, 0) + v
    if not plans:
        LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id)
        return
    sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
    tasks: list[asyncio.Task] = []
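    # Bounded fan-out: the semaphore is acquired *before* each worker task is
    # spawned and released in the worker's finally block, so at most
    # DEFAULT_PER_SCAN_CONCURRENCY ingestions are in flight at any moment.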
    for abs_path in plans:

        async def worker(fp_abs: str = abs_path):
            try:
                # Offload sync ingestion into a thread; populate_db_with_asset
                # derives name and tags from the path using _assets_helpers.
                await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
            except Exception as e:
                prog.errors += 1
                prog.last_error = str(e)
                LOGGER.debug("Error ingesting %s: %s", fp_abs, e)
            finally:
                prog.processed += 1
                sem.release()

        await sem.acquire()
        tasks.append(asyncio.create_task(worker()))
    if tasks:
        await asyncio.gather(*tasks)
    LOGGER.info(
        "Model scan %s finished: discovered=%d processed=%d errors=%d",
        prog.scan_id, prog.discovered, prog.processed, prog.errors,
    )


def _count_files_in_tree(base_abs: str) -> int:
    if not os.path.isdir(base_abs):
        return 0
    total = 0
    for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
        total += len(filenames)
    return total


async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
    """
    Generic scanner for the input/output roots. We pass only the absolute path
    to populate_db_with_asset and let it derive the relative name and tags.
    """
    base_abs = os.path.abspath(base_dir)
    if not os.path.isdir(base_abs):
        LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
        return
    prog.discovered = _count_files_in_tree(base_abs)
    sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
    tasks: list[asyncio.Task] = []
    for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
        for name in filenames:
            abs_path = os.path.abspath(os.path.join(dirpath, name))
            # Safety: ensure the path is still within the base directory.
            try:
                if os.path.commonpath([abs_path, base_abs]) != base_abs:
                    LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
                    continue
            except ValueError:
                continue
            # Skip empty files and handle stat errors.
            try:
                if not os.path.getsize(abs_path):
                    continue
            except OSError as e:
                LOGGER.warning("Could not stat %s (%s); skipping", abs_path, e)
                continue

            async def worker(fp_abs: str = abs_path):
                try:
                    await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
                except Exception as e:
                    prog.errors += 1
                    prog.last_error = str(e)
                finally:
                    prog.processed += 1
                    sem.release()

            await sem.acquire()
            tasks.append(asyncio.create_task(worker()))
    if tasks:
        await asyncio.gather(*tasks)
    LOGGER.info(
        "%s scan %s finished: discovered=%d processed=%d errors=%d",
        root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors,
    )


def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
    if ts is None:
        return None
    # Interpret ts as seconds since the epoch (UTC) and return a naive UTC
    # timestamp string, consistent with the other models.
    try:
        return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
    except Exception:
        return None
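
# Sanity examples (illustrative):
#   _ts_to_iso(0.0)  -> "1970-01-01T00:00:00"
#   _ts_to_iso(None) -> None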