added Assets Autoscan feature

bigcat88 2025-09-05 17:46:09 +03:00
parent bf8363ec87
commit ce270ba090
GPG Key ID: 1F0BF0EC3CF22721
5 changed files with 364 additions and 186 deletions

View File

@@ -92,6 +92,13 @@ class TagsRemove(BaseModel):
total_tags: list[str] = Field(default_factory=list)
class AssetScanError(BaseModel):
path: str
message: str
phase: Literal["fast", "slow"]
at: Optional[str] = Field(None, description="ISO timestamp")
class AssetScanStatus(BaseModel):
scan_id: str
root: Literal["models", "input", "output"]
@@ -101,8 +108,9 @@ class AssetScanStatus(BaseModel):
finished_at: Optional[str] = None
discovered: int = 0
processed: int = 0
errors: int = 0
last_error: Optional[str] = None
slow_queue_total: int = 0
slow_queue_finished: int = 0
file_errors: list[AssetScanError] = Field(default_factory=list)
class AssetScanStatusResponse(BaseModel):
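
The new models replace the single errors/last_error pair with structured per-file error records plus slow-queue counters. A minimal sketch of one such record, with a hypothetical path and message, assuming the Pydantic models exactly as declared above:

err = AssetScanError(
    path="/some/models/checkpoints/broken.safetensors",  # hypothetical path
    message="Permission denied",
    phase="fast",  # "fast" = quick stat/DB check, "slow" = hashing/ingest
    at="2025-09-05T14:46:09+00:00",  # ISO timestamp
)
# Entries like this end up in AssetScanStatus.file_errors for the affected scan.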

View File

@@ -43,7 +43,7 @@ async def asset_exists(*, asset_hash: str) -> bool:
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if not args.disable_model_processing:
if not args.enable_model_processing:
if tags is None:
tags = []
try:
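
populate_db_with_asset stays synchronous; the scanner below offloads it with asyncio.to_thread. A minimal sketch of ingesting a single file that way, with a hypothetical path and no explicit tags (the helper derives name and tags from the path):

import asyncio

async def ingest_one(abs_path: str) -> None:
    # populate_db_with_asset is blocking, so run it in a worker thread
    await asyncio.to_thread(populate_db_with_asset, abs_path)

# asyncio.run(ingest_one("/some/models/checkpoints/model.safetensors"))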

View File

@@ -5,22 +5,22 @@ import uuid
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Literal, Optional, Sequence
from typing import Callable, Literal, Optional, Sequence
import folder_paths
from . import assets_manager
from .api import schemas_out
from ._assets_helpers import get_comfy_models_folders
from .database.db import create_session
from .database.services import check_fs_asset_exists_quick
LOGGER = logging.getLogger(__name__)
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
# We run at most one scan per root; overall max parallelism is therefore 3
# We also bound per-scan ingestion concurrency to avoid swamping threads/DB
DEFAULT_PER_SCAN_CONCURRENCY = 1
SLOW_HASH_CONCURRENCY = 1
@dataclass
@@ -34,15 +34,25 @@ class ScanProgress:
discovered: int = 0
processed: int = 0
errors: int = 0
last_error: Optional[str] = None
slow_queue_total: int = 0
slow_queue_finished: int = 0
file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"}
# Optional details for diagnostics (e.g., files per bucket)
details: dict[str, int] = field(default_factory=dict)
# Internal diagnostics for logs
_fast_total_seen: int = 0
_fast_clean: int = 0
@dataclass
class SlowQueueState:
queue: asyncio.Queue
workers: list[asyncio.Task] = field(default_factory=list)
closed: bool = False
RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {}
def _new_scan_id(root: RootType) -> str:
@@ -50,23 +60,13 @@ def _new_scan_id(root: RootType) -> str:
def current_statuses() -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(
scans=[
schemas_out.AssetScanStatus(
scan_id=s.scan_id,
root=s.root,
status=s.status,
scheduled_at=_ts_to_iso(s.scheduled_at),
started_at=_ts_to_iso(s.started_at),
finished_at=_ts_to_iso(s.finished_at),
discovered=s.discovered,
processed=s.processed,
errors=s.errors,
last_error=s.last_error,
)
for s in [PROGRESS_BY_ROOT[r] for r in ALLOWED_ROOTS if r in PROGRESS_BY_ROOT]
]
)
scans = []
for root in ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
scans.append(_scan_progress_to_scan_status_model(prog))
return schemas_out.AssetScanStatusResponse(scans=scans)
async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
@@ -81,8 +81,6 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
normalized: list[RootType] = []
seen = set()
for r in roots or []:
if not isinstance(r, str):
continue
rr = r.strip().lower()
if rr in ALLOWED_ROOTS and rr not in seen:
normalized.append(rr) # type: ignore
@@ -98,142 +96,311 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
task = asyncio.create_task(_run_scan_for_root(root, prog), name=f"asset-scan:{root}")
RUNNING_TASKS[root] = task
SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue())
RUNNING_TASKS[root] = asyncio.create_task(
_pipeline_for_root(root, prog, progress_cb=None),
name=f"asset-scan:{root}",
)
results.append(prog)
return schemas_out.AssetScanStatusResponse(
scans=[
schemas_out.AssetScanStatus(
scan_id=s.scan_id,
root=s.root,
status=s.status,
scheduled_at=_ts_to_iso(s.scheduled_at),
started_at=_ts_to_iso(s.started_at),
finished_at=_ts_to_iso(s.finished_at),
discovered=s.discovered,
processed=s.processed,
errors=s.errors,
last_error=s.last_error,
)
for s in results
]
)
return _status_response_for(results)
async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
prog.started_at = time.time()
async def fast_reconcile_and_kickoff(
roots: Sequence[str] | None = None,
*,
progress_cb: Optional[Callable[[dict], None]] = None,
) -> schemas_out.AssetScanStatusResponse:
"""
Startup helper: do the fast pass now (so we know queue size),
start slow hashing in the background, return immediately.
"""
normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS]
snaps: list[ScanProgress] = []
for root in normalized:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
snaps.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
prog.status = "running"
prog.started_at = time.time()
try:
if root == "models":
await _scan_models(prog)
elif root == "input":
base = folder_paths.get_input_directory()
await _scan_directory_tree(base, root, prog)
elif root == "output":
base = folder_paths.get_output_directory()
await _scan_directory_tree(base, root, prog)
else:
raise RuntimeError(f"Unsupported root: {root}")
prog.status = "completed"
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
except Exception as e:
_append_error(prog, phase="fast", path="", message=str(e))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Fast reconcile failed for %s", root)
snaps.append(prog)
continue
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
RUNNING_TASKS[root] = asyncio.create_task(
_await_workers_then_finish(root, prog, state, progress_cb=progress_cb),
name=f"asset-hash:{root}",
)
snaps.append(prog)
return _status_response_for(snaps)
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=[_scan_progress_to_scan_status_model(p) for p in progresses])
def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.AssetScanStatus:
return schemas_out.AssetScanStatus(
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=_ts_to_iso(progress.scheduled_at),
started_at=_ts_to_iso(progress.started_at),
finished_at=_ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
slow_queue_total=progress.slow_queue_total,
slow_queue_finished=progress.slow_queue_finished,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
phase=e.get("phase", "slow"),
at=e.get("at"),
)
for e in (progress.file_errors or [])
],
)
async def _pipeline_for_root(
root: RootType,
prog: ScanProgress,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
prog.status = "running"
prog.started_at = time.time()
try:
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
LOGGER.exception("Asset scan failed for %s", root)
_append_error(prog, phase="slow", path="", message=str(exc))
prog.status = "failed"
prog.errors += 1
prog.last_error = str(exc)
finally:
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
finally:
RUNNING_TASKS.pop(root, None)
async def _scan_models(prog: ScanProgress) -> None:
async def _fast_reconcile_into_queue(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
"""
Scan all configured model buckets from folder_paths.folder_names_and_paths,
restricted to entries whose base paths lie under folder_paths.models_dir
(per get_comfy_models_folders). We trust those mappings and do not try to
infer anything else here.
Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files,
and queue the rest for slow hashing.
"""
targets: list[tuple[str, list[str]]] = get_comfy_models_folders()
if root == "models":
files = _collect_models_files()
preset_discovered = len(files)
files_iter = asyncio.Queue()
for p in files:
await files_iter.put(p)
await files_iter.put(None) # sentinel for our local draining loop
elif root == "input":
base = folder_paths.get_input_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base))
files_iter = await _queue_tree_files(base)
elif root == "output":
base = folder_paths.get_output_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base))
files_iter = await _queue_tree_files(base)
else:
raise RuntimeError(f"Unsupported root: {root}")
plans: list[str] = [] # absolute file paths to ingest
per_bucket: dict[str, int] = {}
prog.discovered = int(preset_discovered or 0)
for folder_name, bases in targets:
queued = 0
checked = 0
clean = 0
# Single session for the whole fast pass
async with await create_session() as sess:
while True:
item = await files_iter.get()
files_iter.task_done()
if item is None:
break
abs_path = item
checked += 1
# Stat; skip empty/unreadable
try:
st = os.stat(abs_path, follow_symlinks=True)
if not st.st_size:
continue
size_bytes = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
except OSError as e:
_append_error(prog, phase="fast", path=abs_path, message=str(e))
continue
# Known good -> count as processed immediately
try:
known = await check_fs_asset_exists_quick(
sess,
file_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
)
except Exception as e:
_append_error(prog, phase="fast", path=abs_path, message=str(e))
known = False
if known:
clean += 1
prog.processed += 1 # preserve original semantics
else:
await state.queue.put(abs_path)
queued += 1
prog.slow_queue_total += 1
if progress_cb:
progress_cb({
"root": root,
"phase": "fast",
"checked": checked,
"clean": clean,
"queued": queued,
"discovered": prog.discovered,
"processed": prog.processed,
})
prog._fast_total_seen = checked
prog._fast_clean = clean
if progress_cb:
progress_cb({
"root": root,
"phase": "fast",
"checked": checked,
"clean": clean,
"queued": queued,
"discovered": prog.discovered,
"processed": prog.processed,
"done": True,
})
state.closed = True
def _start_slow_workers(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
if state.workers:
return
async def _worker(_worker_id: int):
while True:
item = await state.queue.get()
try:
if item is None:
return
try:
await asyncio.to_thread(assets_manager.populate_db_with_asset, item)
except Exception as e:
_append_error(prog, phase="slow", path=item, message=str(e))
finally:
# Slow queue finished for this item; also counts toward overall processed
prog.slow_queue_finished += 1
prog.processed += 1
if progress_cb:
progress_cb({
"root": root,
"phase": "slow",
"processed": prog.processed,
"slow_queue_finished": prog.slow_queue_finished,
"slow_queue_total": prog.slow_queue_total,
})
finally:
state.queue.task_done()
state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)]
async def _close_when_empty():
# When the fast phase closed the queue, push sentinels to end workers
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
asyncio.create_task(_close_when_empty())
async def _await_workers_then_finish(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[dict], None]],
) -> None:
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
prog.finished_at = time.time()
prog.status = "completed"
if progress_cb:
progress_cb({
"root": root,
"phase": "slow",
"processed": prog.processed,
"slow_queue_finished": prog.slow_queue_finished,
"slow_queue_total": prog.slow_queue_total,
"done": True,
})
def _collect_models_files() -> list[str]:
"""Collect absolute file paths from configured model buckets under models_dir."""
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
count_valid = 0
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
# Extra safety: ensure file is inside one of the allowed base paths
# ensure within allowed bases
allowed = False
for base in bases:
base_abs = os.path.abspath(base)
for b in bases:
base_abs = os.path.abspath(b)
try:
common = os.path.commonpath([abs_path, base_abs])
except ValueError:
common = "" # Different drives on Windows
if common == base_abs:
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if not allowed:
LOGGER.warning("Skipping file outside models base: %s", abs_path)
continue
try:
if not os.path.getsize(abs_path):
continue # skip empty files
except OSError as e:
LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
continue
plans.append(abs_path)
count_valid += 1
if count_valid:
per_bucket[folder_name] = per_bucket.get(folder_name, 0) + count_valid
prog.discovered = len(plans)
for k, v in per_bucket.items():
prog.details[k] = prog.details.get(k, 0) + v
if not plans:
LOGGER.info("Model scan %s: nothing to ingest", prog.scan_id)
return
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
tasks: list[asyncio.Task] = []
for abs_path in plans:
async def worker(fp_abs: str = abs_path):
try:
# Offload sync ingestion into a thread; populate_db_with_asset
# derives name and tags from the path using _assets_helpers.
await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
except Exception as e:
prog.errors += 1
prog.last_error = str(e)
LOGGER.debug("Error ingesting %s: %s", fp_abs, e)
finally:
prog.processed += 1
sem.release()
await sem.acquire()
tasks.append(asyncio.create_task(worker()))
if tasks:
await asyncio.gather(*tasks)
LOGGER.info(
"Model scan %s finished: discovered=%d processed=%d errors=%d",
prog.scan_id, prog.discovered, prog.processed, prog.errors
)
except Exception:
pass
if allowed:
out.append(abs_path)
return out
def _count_files_in_tree(base_abs: str) -> int:
@@ -245,60 +412,37 @@ def _count_files_in_tree(base_abs: str) -> int:
return total
async def _scan_directory_tree(base_dir: str, root: RootType, prog: ScanProgress) -> None:
async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
"""
Generic scanner for input/output roots. We pass only the absolute path to
populate_db_with_asset and let it derive the relative name and tags.
Walk base_dir in a worker thread and return a queue prefilled with all paths,
terminated by a single None sentinel for the draining loop in fast reconcile.
"""
q: asyncio.Queue = asyncio.Queue()
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
LOGGER.info("Scan root %s skipped: base directory missing: %s", root, base_abs)
return
await q.put(None)
return q
prog.discovered = _count_files_in_tree(base_abs)
sem = asyncio.Semaphore(DEFAULT_PER_SCAN_CONCURRENCY)
tasks: list[asyncio.Task] = []
def _walk_list():
paths: list[str] = []
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
abs_path = os.path.abspath(os.path.join(dirpath, name))
paths.append(os.path.abspath(os.path.join(dirpath, name)))
return paths
# Safety: ensure within base
try:
if os.path.commonpath([abs_path, base_abs]) != base_abs:
LOGGER.warning("Skipping path outside root %s: %s", root, abs_path)
continue
except ValueError:
continue
for p in await asyncio.to_thread(_walk_list):
await q.put(p)
await q.put(None)
return q
# Skip empty files and handle stat errors
try:
if not os.path.getsize(abs_path):
continue
except OSError as e:
LOGGER.warning("Could not stat %s: %s skipping", abs_path, e)
continue
async def worker(fp_abs: str = abs_path):
try:
await asyncio.to_thread(assets_manager.populate_db_with_asset, fp_abs)
except Exception as e:
prog.errors += 1
prog.last_error = str(e)
finally:
prog.processed += 1
sem.release()
await sem.acquire()
tasks.append(asyncio.create_task(worker()))
if tasks:
await asyncio.gather(*tasks)
LOGGER.info(
"%s scan %s finished: discovered=%d processed=%d errors=%d",
root.capitalize(), prog.scan_id, prog.discovered, prog.processed, prog.errors
)
def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None:
prog.file_errors.append({
"path": path,
"message": message,
"phase": phase,
"at": _ts_to_iso(time.time()),
})
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
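
Taken together, the scanner now runs in two phases: a fast reconcile that stats each file and calls check_fs_asset_exists_quick, counting known files as processed and queueing the rest, followed by slow workers that hash/ingest queued files via populate_db_with_asset and are shut down with one None sentinel per worker. A condensed, self-contained sketch of that queue-and-sentinel pattern (illustrative names, not the module above):

import asyncio

SLOW_HASH_CONCURRENCY = 1  # mirrors the bound used above

async def scan(paths, is_known, ingest):
    queue: asyncio.Queue = asyncio.Queue()

    async def worker() -> None:
        while True:
            item = await queue.get()
            try:
                if item is None:  # sentinel: no more work for this worker
                    return
                await asyncio.to_thread(ingest, item)  # slow phase (hashing/ingest)
            finally:
                queue.task_done()

    workers = [asyncio.create_task(worker()) for _ in range(SLOW_HASH_CONCURRENCY)]
    for path in paths:  # fast phase: cheap checks only
        if not await is_known(path):
            await queue.put(path)
    for _ in range(SLOW_HASH_CONCURRENCY):  # one sentinel per worker
        await queue.put(None)
    await asyncio.gather(*workers)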

View File

@@ -212,7 +212,8 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--disable-model-processing", action="store_true", help="Disable automatic processing of the model file, such as calculating hashes and populating the database.")
parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:
args = parser.parse_args()
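
The model-processing switch flips from an opt-out (--disable-model-processing) to an explicit --enable-model-processing flag, and the asset autoscan gains its own opt-out. A hypothetical invocation that starts the server without the automatic startup scan:

python main.py --disable-assets-autoscan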

25  main.py
View File

@@ -279,10 +279,35 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
async def setup_database():
def _console_cb(e: dict):
root = e.get("root")
phase = e.get("phase")
if phase == "fast":
if e.get("done"):
logging.info(
f"[assets][{root}] fast done: processed={e['processed']}/{e['discovered']} queued={e['queued']}"
)
elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress
logging.info(f"[assets][{root}] fast progress: processed={e['processed']}/{e['discovered']}"
)
elif phase == "slow":
if e.get("done"):
logging.info(
f"[assets][{root}] slow done: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}"
)
else:
logging.info(
f"[assets][{root}] slow progress: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}"
)
try:
from app.database.db import init_db_engine, dependencies_available
if dependencies_available():
await init_db_engine()
if not args.disable_assets_autoscan:
from app import assets_scanner
await assets_scanner.fast_reconcile_and_kickoff(progress_cb=_console_cb)
except Exception as e:
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
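
_console_cb is only one consumer of the scan events; fast_reconcile_and_kickoff accepts any callable that takes the event dict. A minimal sketch of a quieter callback, assuming the event keys emitted by the scanner above (root, phase, discovered, processed, slow_queue_total, slow_queue_finished, done):

def quiet_cb(e: dict) -> None:
    # Only report phase completion; ignore incremental progress events.
    if not e.get("done"):
        return
    root, phase = e.get("root"), e.get("phase")
    if phase == "fast":
        logging.info(f"[assets][{root}] fast pass done: {e['processed']}/{e['discovered']} files already known")
    else:
        logging.info(f"[assets][{root}] hashing done: {e.get('slow_queue_finished', 0)}/{e.get('slow_queue_total', 0)}")

# e.g. await assets_scanner.fast_reconcile_and_kickoff(progress_cb=quiet_cb)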