added Assets Autoscan feature

bigcat88 2025-09-05 17:46:09 +03:00
parent bf8363ec87
commit ce270ba090
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
5 changed files with 364 additions and 186 deletions

View File

@@ -92,17 +92,25 @@ class TagsRemove(BaseModel):
     total_tags: list[str] = Field(default_factory=list)
 
 
+class AssetScanError(BaseModel):
+    path: str
+    message: str
+    phase: Literal["fast", "slow"]
+    at: Optional[str] = Field(None, description="ISO timestamp")
+
+
 class AssetScanStatus(BaseModel):
     scan_id: str
-    root: Literal["models","input","output"]
-    status: Literal["scheduled","running","completed","failed","cancelled"]
+    root: Literal["models", "input", "output"]
+    status: Literal["scheduled", "running", "completed", "failed", "cancelled"]
     scheduled_at: Optional[str] = None
     started_at: Optional[str] = None
     finished_at: Optional[str] = None
     discovered: int = 0
     processed: int = 0
-    errors: int = 0
-    last_error: Optional[str] = None
+    slow_queue_total: int = 0
+    slow_queue_finished: int = 0
+    file_errors: list[AssetScanError] = Field(default_factory=list)
 
 
 class AssetScanStatusResponse(BaseModel):
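
(For orientation, a single entry serialized from these models might look roughly like the sketch below. The values are invented for illustration; only the field names and the "fast"/"slow" and status literals come from the schema above, and the scan_id format is not shown in this diff.)

example_scan_status = {
    "scan_id": "<opaque id from _new_scan_id()>",
    "root": "models",
    "status": "running",
    "scheduled_at": "2025-09-05T14:46:09+00:00",
    "started_at": "2025-09-05T14:46:10+00:00",
    "finished_at": None,
    "discovered": 150,          # files seen by the fast pass
    "processed": 120,           # fast-matched files; slow-hashed files are added as they finish
    "slow_queue_total": 30,     # files queued for slow hashing
    "slow_queue_finished": 0,
    "file_errors": [
        {"path": "<unreadable file>", "message": "Permission denied", "phase": "fast", "at": "2025-09-05T14:46:11+00:00"},
    ],
}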

View File

@@ -43,7 +43,7 @@ async def asset_exists(*, asset_hash: str) -> bool:
 
 
 def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
-    if not args.disable_model_processing:
+    if not args.enable_model_processing:
         if tags is None:
             tags = []
         try:

View File

@@ -5,22 +5,22 @@ import uuid
 import time
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
-from typing import Literal, Optional, Sequence
+from typing import Callable, Literal, Optional, Sequence
 
 import folder_paths
 
 from . import assets_manager
 from .api import schemas_out
 from ._assets_helpers import get_comfy_models_folders
+from .database.db import create_session
+from .database.services import check_fs_asset_exists_quick
 
 LOGGER = logging.getLogger(__name__)
 
 RootType = Literal["models", "input", "output"]
 ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
 
-# We run at most one scan per root; overall max parallelism is therefore 3
-# We also bound per-scan ingestion concurrency to avoid swamping threads/DB
-DEFAULT_PER_SCAN_CONCURRENCY = 1
+SLOW_HASH_CONCURRENCY = 1
 
 
 @dataclass
@@ -34,15 +34,25 @@ class ScanProgress:
     discovered: int = 0
     processed: int = 0
-    errors: int = 0
-    last_error: Optional[str] = None
+    slow_queue_total: int = 0
+    slow_queue_finished: int = 0
+    file_errors: list[dict] = field(default_factory=list)  # {"path","message","phase","at"}
 
-    # Optional details for diagnostics (e.g., files per bucket)
-    details: dict[str, int] = field(default_factory=dict)
+    # Internal diagnostics for logs
+    _fast_total_seen: int = 0
+    _fast_clean: int = 0
+
+
+@dataclass
+class SlowQueueState:
+    queue: asyncio.Queue
+    workers: list[asyncio.Task] = field(default_factory=list)
+    closed: bool = False
 
 
 RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
 PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
+SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {}
 
 
 def _new_scan_id(root: RootType) -> str:
@@ -50,23 +60,13 @@ def _new_scan_id(root: RootType) -> str:
 
 
 def current_statuses() -> schemas_out.AssetScanStatusResponse:
-    return schemas_out.AssetScanStatusResponse(
-        scans=[
-            schemas_out.AssetScanStatus(
-                scan_id=s.scan_id,
-                root=s.root,
-                status=s.status,
-                scheduled_at=_ts_to_iso(s.scheduled_at),
-                started_at=_ts_to_iso(s.started_at),
-                finished_at=_ts_to_iso(s.finished_at),
-                discovered=s.discovered,
-                processed=s.processed,
-                errors=s.errors,
-                last_error=s.last_error,
-            )
-            for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
-        ]
-    )
+    scans = []
+    for root in ALLOWED_ROOTS:
+        prog = PROGRESS_BY_ROOT.get(root)
+        if not prog:
+            continue
+        scans.append(_scan_progress_to_scan_status_model(prog))
+    return schemas_out.AssetScanStatusResponse(scans=scans)
 
 
 async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
@@ -81,8 +81,6 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
     normalized: list[RootType] = []
     seen = set()
     for r in roots or []:
-        if not isinstance(r, str):
-            continue
         rr = r.strip().lower()
         if rr in ALLOWED_ROOTS and rr not in seen:
             normalized.append(rr)  # type: ignore
@@ -98,142 +96,311 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
         prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
         PROGRESS_BY_ROOT[root] = prog
+        SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue())
 
-        task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
-        RUNNING_TASKS[root] = task
+        RUNNING_TASKS[root] = asyncio.create_task(
+            _pipeline_for_root(root, prog, progress_cb=None),
+            name=f"asset-scan:{root}",
+        )
         results.append(prog)
 
-    return schemas_out.AssetScanStatusResponse(
-        scans=[
-            schemas_out.AssetScanStatus(
-                scan_id=s.scan_id,
-                root=s.root,
-                status=s.status,
-                scheduled_at=_ts_to_iso(s.scheduled_at),
-                started_at=_ts_to_iso(s.started_at),
-                finished_at=_ts_to_iso(s.finished_at),
-                discovered=s.discovered,
-                processed=s.processed,
-                errors=s.errors,
-                last_error=s.last_error,
-            )
-            for s in results
-        ]
-    )
+    return _status_response_for(results)
 
 
-async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
-    prog.started_at = time.time()
-    prog.status = "running"
-    try:
-        if root == "models":
-            await _scan_models(prog)
-        elif root == "input":
-            base = folder_paths.get_input_directory()
-            await _scan_directory_tree(base, root, prog)
-        elif root == "output":
-            base = folder_paths.get_output_directory()
-            await _scan_directory_tree(base, root, prog)
-        else:
-            raise RuntimeError(f"Unsupported root: {root}")
-        prog.status = "completed"
-    except asyncio.CancelledError:
-        prog.status = "cancelled"
-        raise
-    except Exception as exc:
-        LOGGER.exception("Asset scan failed for %s", root)
-        prog.status = "failed"
-        prog.errors += 1
-        prog.last_error = str(exc)
-    finally:
-        prog.finished_at = time.time()
-        RUNNING_TASKS.pop(root, None)
+async def fast_reconcile_and_kickoff(
+    roots: Sequence[str] | None = None,
+    *,
+    progress_cb: Optional[Callable[[dict], None]] = None,
+) -> schemas_out.AssetScanStatusResponse:
+    """
+    Startup helper: do the fast pass now (so we know queue size),
+    start slow hashing in the background, return immediately.
+    """
+    normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS]
+    snaps: list[ScanProgress] = []
+    for root in normalized:
+        if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
+            snaps.append(PROGRESS_BY_ROOT[root])
+            continue
+        prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
+        PROGRESS_BY_ROOT[root] = prog
+        state = SlowQueueState(queue=asyncio.Queue())
+        SLOW_STATE_BY_ROOT[root] = state
+
+        prog.status = "running"
+        prog.started_at = time.time()
+        try:
+            await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
+        except Exception as e:
+            _append_error(prog, phase="fast", path="", message=str(e))
+            prog.status = "failed"
+            prog.finished_at = time.time()
+            LOGGER.exception("Fast reconcile failed for %s", root)
+            snaps.append(prog)
+            continue
+
+        _start_slow_workers(root, prog, state, progress_cb=progress_cb)
+        RUNNING_TASKS[root] = asyncio.create_task(
+            _await_workers_then_finish(root, prog, state, progress_cb=progress_cb),
+            name=f"asset-hash:{root}",
+        )
+        snaps.append(prog)
+    return _status_response_for(snaps)
+
+
+def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
+    return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
+
+
+def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
+    return schemas_out.AssetScanStatus(
+        scan_id=progress.scan_id,
+        root=progress.root,
+        status=progress.status,
+        scheduled_at=_ts_to_iso(progress.scheduled_at),
+        started_at=_ts_to_iso(progress.started_at),
+        finished_at=_ts_to_iso(progress.finished_at),
+        discovered=progress.discovered,
+        processed=progress.processed,
+        slow_queue_total=progress.slow_queue_total,
+        slow_queue_finished=progress.slow_queue_finished,
+        file_errors=[
+            schemas_out.AssetScanError(
+                path=e.get("path", ""),
+                message=e.get("message", ""),
+                phase=e.get("phase", "slow"),
+                at=e.get("at"),
+            )
+            for e in (progress.file_errors or [])
+        ],
+    )
+
+
+async def _pipeline_for_root(
+    root: RootType,
+    prog: ScanProgress,
+    progress_cb: Optional[Callable[[dict], None]],
+) -> None:
+    state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue())
+    SLOW_STATE_BY_ROOT[root] = state
+    prog.status = "running"
+    prog.started_at = time.time()
+    try:
+        await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
+        _start_slow_workers(root, prog, state, progress_cb=progress_cb)
+        await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb)
+    except asyncio.CancelledError:
+        prog.status = "cancelled"
+        raise
+    except Exception as exc:
+        _append_error(prog, phase="slow", path="", message=str(exc))
+        prog.status = "failed"
+        prog.finished_at = time.time()
+        LOGGER.exception("Asset scan failed for %s", root)
+    finally:
+        RUNNING_TASKS.pop(root, None)
+
+
+async def _fast_reconcile_into_queue(
+    root: RootType,
+    prog: ScanProgress,
+    state: SlowQueueState,
+    *,
+    progress_cb: Optional[Callable[[dict], None]],
+) -> None:
+    """
+    Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files,
+    and queue the rest for slow hashing.
+    """
+    if root == "models":
+        files = _collect_models_files()
+        preset_discovered = len(files)
+        files_iter = asyncio.Queue()
+        for p in files:
+            await files_iter.put(p)
+        await files_iter.put(None)  # sentinel for our local draining loop
+    elif root == "input":
+        base = folder_paths.get_input_directory()
+        preset_discovered = _count_files_in_tree(os.path.abspath(base))
+        files_iter = await _queue_tree_files(base)
+    elif root == "output":
+        base = folder_paths.get_output_directory()
+        preset_discovered = _count_files_in_tree(os.path.abspath(base))
+        files_iter = await _queue_tree_files(base)
+    else:
+        raise RuntimeError(f"Unsupported root: {root}")
+
+    prog.discovered = int(preset_discovered or 0)
+
+    queued = 0
+    checked = 0
+    clean = 0
+
+    # Single session for the whole fast pass
+    async with await create_session() as sess:
+        while True:
+            item = await files_iter.get()
+            files_iter.task_done()
+            if item is None:
+                break
+            abs_path = item
+            checked += 1
+
+            # Stat; skip empty/unreadable
+            try:
+                st = os.stat(abs_path, follow_symlinks=True)
+                if not st.st_size:
+                    continue
+                size_bytes = int(st.st_size)
+                mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
+            except OSError as e:
+                _append_error(prog, phase="fast", path=abs_path, message=str(e))
+                continue
+
+            # Known good -> count as processed immediately
+            try:
+                known = await check_fs_asset_exists_quick(
+                    sess,
+                    file_path=abs_path,
+                    size_bytes=size_bytes,
+                    mtime_ns=mtime_ns,
+                )
+            except Exception as e:
+                _append_error(prog, phase="fast", path=abs_path, message=str(e))
+                known = False
+
+            if known:
+                clean += 1
+                prog.processed += 1  # preserve original semantics
+            else:
+                await state.queue.put(abs_path)
+                queued += 1
+                prog.slow_queue_total += 1
+
+            if progress_cb:
+                progress_cb({
+                    "root": root,
+                    "phase": "fast",
+                    "checked": checked,
+                    "clean": clean,
+                    "queued": queued,
+                    "discovered": prog.discovered,
+                    "processed": prog.processed,
+                })
+
+    prog._fast_total_seen = checked
+    prog._fast_clean = clean
+    if progress_cb:
+        progress_cb({
+            "root": root,
+            "phase": "fast",
+            "checked": checked,
+            "clean": clean,
+            "queued": queued,
+            "discovered": prog.discovered,
+            "processed": prog.processed,
+            "done": True,
+        })
+    state.closed = True
+
+
+def _start_slow_workers(
+    root: RootType,
+    prog: ScanProgress,
+    state: SlowQueueState,
+    *,
+    progress_cb: Optional[Callable[[dict], None]],
+) -> None:
+    if state.workers:
+        return
+
+    async def _worker(_worker_id: int):
+        while True:
+            item = await state.queue.get()
+            try:
+                if item is None:
+                    return
+                try:
+                    await asyncio.to_thread(assets_manager.populate_db_with_asset, item)
+                except Exception as e:
+                    _append_error(prog, phase="slow", path=item, message=str(e))
+                finally:
+                    # Slow queue finished for this item; also counts toward overall processed
+                    prog.slow_queue_finished += 1
+                    prog.processed += 1
+                    if progress_cb:
+                        progress_cb({
+                            "root": root,
+                            "phase": "slow",
+                            "processed": prog.processed,
+                            "slow_queue_finished": prog.slow_queue_finished,
+                            "slow_queue_total": prog.slow_queue_total,
+                        })
+            finally:
+                state.queue.task_done()
+
+    state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)]
+
+    async def _close_when_empty():
+        # When the fast phase closed the queue, push sentinels to end workers
+        while not state.closed:
+            await asyncio.sleep(0.05)
+        for _ in range(SLOW_HASH_CONCURRENCY):
+            await state.queue.put(None)
+
+    asyncio.create_task(_close_when_empty())
+
+
+async def _await_workers_then_finish(
+    root: RootType,
+    prog: ScanProgress,
+    state: SlowQueueState,
+    *,
+    progress_cb: Optional[Callable[[dict], None]],
+) -> None:
+    if state.workers:
+        await asyncio.gather(*state.workers, return_exceptions=True)
+    prog.finished_at = time.time()
+    prog.status = "completed"
+    if progress_cb:
+        progress_cb({
+            "root": root,
+            "phase": "slow",
+            "processed": prog.processed,
+            "slow_queue_finished": prog.slow_queue_finished,
+            "slow_queue_total": prog.slow_queue_total,
+            "done": True,
+        })
+
+
+def _collect_models_files() -> list[str]:
+    """Collect absolute file paths from configured model buckets under models_dir."""
+    out: list[str] = []
+    for folder_name, bases in get_comfy_models_folders():
+        rel_files = folder_paths.get_filename_list(folder_name) or []
+        for rel_path in rel_files:
+            abs_path = folder_paths.get_full_path(folder_name, rel_path)
+            if not abs_path:
+                continue
+            abs_path = os.path.abspath(abs_path)
+            # ensure within allowed bases
+            allowed = False
+            for b in bases:
+                base_abs = os.path.abspath(b)
+                try:
+                    if os.path.commonpath([abs_path, base_abs]) == base_abs:
+                        allowed = True
+                        break
+                except Exception:
+                    pass
+            if allowed:
+                out.append(abs_path)
+    return out
 
 
-async def _scan_models(prog: ScanProgress) -> None:
-    """
-    Scan all configured model buckets from folder_paths.folder_names_and_paths,
-    restricted to entries whose base paths lie under folder_paths.models_dir
-    (per get_comfy_models_folders). We trust those mappings and do not try to
-    infer anything else here.
-    """
-    targets: list[tuple[str, list[str]]] = get_comfy_models_folders()
-
-    plans: list[str] = []  # absolute file paths to ingest
-    per_bucket: dict[str, int] = {}
-
-    for folder_name, bases in targets:
-        rel_files = folder_paths.get_filename_list(folder_name) or []
-        count_valid = 0
-        for rel_path in rel_files:
-            abs_path = folder_paths.get_full_path(folder_name, rel_path)
-            if not abs_path:
-                continue
-            abs_path = os.path.abspath(abs_path)
-
-            # Extra safety: ensure file is inside one of the allowed base paths
-            allowed = False
-            for base in bases:
-                base_abs = os.path.abspath(base)
-                try:
-                    common = os.path.commonpath([abs_path, base_abs])
-                except ValueError:
-                    common = ""  # Different drives on Windows
-                if common == base_abs:
-                    allowed = True
-                    break
-            if not allowed:
-                LOGGER.warning("Skipping file outside models base: %s", abs_path)
-                continue
-
-            try:
-                if not os.path.getsize(abs_path):
-                    continue  # skip empty files
-            except OSError as e:
-                LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
-                continue
-
-            plans.append(abs_path)
-            count_valid += 1
-        if count_valid:
-            per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid
-
-    prog.discovered = len(plans)
-    for k, v in per_bucket.items():
-        prog.details[k] = prog.details.get(k, 0) + v
-
-    if not plans:
-        LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id)
-        return
-
-    sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
-    tasks: list[asyncio.Task] = []
-
-    for abs_path in plans:
-        async def worker(fp_abs: str = abs_path):
-            try:
-                # Offload sync ingestion into a thread; populate_db_with_asset
-                # derives name and tags from the path using _assets_helpers.
-                await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
-            except Exception as e:
-                prog.errors += 1
-                prog.last_error = str(e)
-                LOGGER.debug("Error ingesting %s: %s", fp_abs, e)
-            finally:
-                prog.processed += 1
-                sem.release()
-
-        await sem.acquire()
-        tasks.append(asyncio.create_task(worker()))
-
-    if tasks:
-        await asyncio.gather(*tasks)
-
-    LOGGER.info(
-        "Model scan %s finished: discovered=%d processed=%d errors=%d",
-        prog.scan_id, prog.discovered, prog.processed, prog.errors
-    )
 
 
 def _count_files_in_tree(base_abs: str) -> int:
@@ -245,60 +412,37 @@ def _count_files_in_tree(base_abs: str) -> int:
     return total
 
 
-async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
-    """
-    Generic scanner for input/output roots. We pass only the absolute path to
-    populate_db_with_asset and let it derive the relative name and tags.
-    """
-    base_abs = os.path.abspath(base_dir)
-    if not os.path.isdir(base_abs):
-        LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
-        return
-
-    prog.discovered = _count_files_in_tree(base_abs)
-
-    sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
-    tasks: list[asyncio.Task] = []
-
-    for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
-        for name in filenames:
-            abs_path = os.path.abspath(os.path.join(dirpath, name))
-
-            # Safety: ensure within base
-            try:
-                if os.path.commonpath([abs_path, base_abs]) != base_abs:
-                    LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
-                    continue
-            except ValueError:
-                continue
-
-            # Skip empty files and handle stat errors
-            try:
-                if not os.path.getsize(abs_path):
-                    continue
-            except OSError as e:
-                LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
-                continue
-
-            async def worker(fp_abs: str = abs_path):
-                try:
-                    await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
-                except Exception as e:
-                    prog.errors += 1
-                    prog.last_error = str(e)
-                finally:
-                    prog.processed += 1
-                    sem.release()
-
-            await sem.acquire()
-            tasks.append(asyncio.create_task(worker()))
-
-    if tasks:
-        await asyncio.gather(*tasks)
-
-    LOGGER.info(
-        "%s scan %s finished: discovered=%d processed=%d errors=%d",
-        root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors
-    )
+async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
+    """
+    Walk base_dir in a worker thread and return a queue prefilled with all paths,
+    terminated by a single None sentinel for the draining loop in fast reconcile.
+    """
+    q: asyncio.Queue = asyncio.Queue()
+    base_abs = os.path.abspath(base_dir)
+    if not os.path.isdir(base_abs):
+        await q.put(None)
+        return q
+
+    def _walk_list():
+        paths: list[str] = []
+        for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
+            for name in filenames:
+                paths.append(os.path.abspath(os.path.join(dirpath, name)))
+        return paths
+
+    for p in await asyncio.to_thread(_walk_list):
+        await q.put(p)
+    await q.put(None)
+    return q
+
+
+def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None:
+    prog.file_errors.append({
+        "path": path,
+        "message": message,
+        "phase": phase,
+        "at": _ts_to_iso(time.time()),
+    })
 
 
 def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
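
(Taken together, schedule_scans() starts the fast-reconcile plus slow-hash pipeline for each root, and current_statuses() reports per-root progress. A minimal driver could look like the sketch below; this is an editorial illustration, not part of the commit, and it assumes the module is importable as app.assets_scanner the way main.py imports it further down.)

import asyncio

from app import assets_scanner


async def run_autoscan_once():
    # Kick off the fast-reconcile + slow-hash pipeline for every root.
    await assets_scanner.schedule_scans(["models", "input", "output"])
    # Poll the aggregated per-root status until every scan reaches a terminal state.
    while True:
        resp = assets_scanner.current_statuses()
        if resp.scans and all(s.status in ("completed", "failed", "cancelled") for s in resp.scans):
            return resp
        await asyncio.sleep(1.0)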

View File

@@ -212,7 +212,8 @@ database_default_path = os.path.abspath(
     os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
 )
 parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
-parser.add_argument("--disable-model-processing", action="store_true", help="Disable automatic processing of the model file, such as calculating hashes and populating the database.")
+parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.")
+parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
 
 if comfy.options.args_parsing:
     args = parser.parse_args()
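
(With these flags, model hashing becomes opt-in and the startup scan can be switched off, for example with an invocation like the following; this is illustrative and assumes the standard main.py entry point.)

python main.py --enable-model-processing --disable-assets-autoscan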

main.py
View File

@@ -279,10 +279,35 @@ def cleanup_temp():
     shutil.rmtree(temp_dir, ignore_errors=True)
 
 
 async def setup_database():
+    def _console_cb(e: dict):
+        root = e.get("root")
+        phase = e.get("phase")
+        if phase == "fast":
+            if e.get("done"):
+                logging.info(
+                    f"[assets][{root}] fast done: processed={e['processed']}/{e['discovered']} queued={e['queued']}"
+                )
+            elif e.get("checked", 0) % 500 == 0:  # do not spam with fast progress
+                logging.info(
+                    f"[assets][{root}] fast progress: processed={e['processed']}/{e['discovered']}"
+                )
+        elif phase == "slow":
+            if e.get("done"):
+                logging.info(
+                    f"[assets][{root}] slow done: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}"
+                )
+            else:
+                logging.info(
+                    f"[assets][{root}] slow progress: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}"
+                )
+
     try:
         from app.database.db import init_db_engine, dependencies_available
 
         if dependencies_available():
             await init_db_engine()
+            if not args.disable_assets_autoscan:
+                from app import assets_scanner
+                await assets_scanner.fast_reconcile_and_kickoff(progress_cb=_console_cb)
     except Exception as e:
         logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")