global refactoring; add support for Assets without the computed hash

bigcat88 2025-09-12 18:14:52 +03:00
parent 934377ac1e
commit bb9ed04758
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
27 changed files with 2380 additions and 1930 deletions

View File

@@ -16,33 +16,44 @@ depends_on = None
def upgrade() -> None:
# ASSETS: content identity (deduplicated by hash)
# ASSETS: content identity
op.create_table(
"assets",
sa.Column("hash", sa.String(length=256), primary_key=True),
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("hash", sa.String(length=256), nullable=True),
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
if op.get_bind().dialect.name == "postgresql":
op.create_index(
"uq_assets_hash_not_null",
"assets",
["hash"],
unique=True,
postgresql_where=sa.text("hash IS NOT NULL"),
)
else:
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
# ASSETS_INFO: user-visible references (mutable metadata)
# ASSETS_INFO: user-visible references
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
@@ -69,18 +80,19 @@ def upgrade() -> None:
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
@@ -99,7 +111,7 @@ def upgrade() -> None:
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# Tags vocabulary for models
# Tags vocabulary
tags_table = sa.table(
"tags",
sa.column("name", sa.String(length=512)),
@@ -108,12 +120,10 @@ def upgrade() -> None:
op.bulk_insert(
tags_table,
[
# Root folder tags
{"name": "models", "tag_type": "system"},
{"name": "input", "tag_type": "system"},
{"name": "output", "tag_type": "system"},
# Core tags
{"name": "configs", "tag_type": "system"},
{"name": "checkpoints", "tag_type": "system"},
{"name": "loras", "tag_type": "system"},
@@ -132,12 +142,11 @@ def upgrade() -> None:
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
# Extra basic tags
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
# Special tags
{"name": "missing", "tag_type": "system"},
{"name": "rescan", "tag_type": "system"},
],
)
@@ -149,8 +158,9 @@ def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
@@ -160,14 +170,18 @@ def downgrade() -> None:
op.drop_index("ix_tags_tag_type", table_name="tags")
op.drop_table("tags")
op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_hash", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
if op.get_bind().dialect.name == "postgresql":
op.drop_index("uq_assets_hash_not_null", table_name="assets")
else:
op.drop_index("uq_assets_hash", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

View File

@@ -1,5 +1,5 @@
from .assets_scanner import sync_seed_assets
from .database.db import init_db_engine
from .assets_scanner import start_background_assets_scan
from .api.assets_routes import register_assets_system
__all__ = ["init_db_engine", "start_background_assets_scan"]
__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"]

View File

@@ -1,12 +1,13 @@
import contextlib
import os
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, Literal, Sequence
import sqlalchemy as sa
from typing import Literal, Optional, Sequence
import folder_paths
from .database.models import AssetInfo
from .api import schemas_in
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
@@ -139,14 +140,6 @@ def ensure_within_base(candidate: str, base: str) -> None:
raise ValueError("invalid destination path")
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def compute_model_relative_filename(file_path: str) -> Optional[str]:
"""
Return the model's path relative to the last well-known folder (the model category),
@@ -172,3 +165,61 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]:
return None
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside) # normalize to POSIX style for portability
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None
def new_scan_id(root: schemas_in.RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
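As an aside, the commonpath comparison in collect_models_files is the path-containment idiom this commit leans on throughout; a small standalone sketch with illustrative paths:

import os

def is_within(candidate: str, base: str) -> bool:
    candidate = os.path.abspath(candidate)
    base = os.path.abspath(base)
    try:
        # containment holds only when the shared prefix is the base itself
        return os.path.commonpath([candidate, base]) == base
    except ValueError:  # e.g. paths on different Windows drives
        return False

print(is_within("/models/checkpoints/sd15.safetensors", "/models"))  # True
print(is_within("/models/../etc/passwd", "/models"))                 # False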

View File

@@ -1,7 +1,7 @@
import contextlib
import os
import uuid
import urllib.parse
import uuid
from typing import Optional
from aiohttp import web
@@ -12,7 +12,6 @@ import folder_paths
from .. import assets_manager, assets_scanner, user_manager
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
UserManager: Optional[user_manager.UserManager] = None
@@ -272,6 +271,7 @@ async def upload_asset(request: web.Request) -> web.Response:
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=owner_id,
expected_asset_hash=spec.hash,
)
status = 201 if created.created_new else 200
return web.json_response(created.model_dump(mode="json"), status=status)
@@ -332,6 +332,29 @@ async def update_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
async def set_asset_preview(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))
try:
body = schemas_in.SetPreviewBody.model_validate(await request.json())
except ValidationError as ve:
return _validation_error_response("INVALID_BODY", ve)
except Exception:
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
try:
result = await assets_manager.set_asset_preview(
asset_info_id=asset_info_id,
preview_asset_id=body.preview_id,
owner_id=UserManager.get_request_user_id(request),
)
except (PermissionError, ValueError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(result.model_dump(mode="json"), status=200)
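A hypothetical client-side sketch for the endpoint above (ids and port are placeholders): set a preview, then clear it by sending null.

import asyncio
import aiohttp

async def main() -> None:
    info_id = "11111111-1111-1111-1111-111111111111"   # AssetInfo.id
    preview = "22222222-2222-2222-2222-222222222222"   # Asset.id of the preview
    async with aiohttp.ClientSession("http://127.0.0.1:8188") as http:
        async with http.put(f"/api/assets/{info_id}/preview",
                            json={"preview_id": preview}) as resp:
            print(resp.status, await resp.json())
        async with http.put(f"/api/assets/{info_id}/preview",
                            json={"preview_id": None}) as resp:
            print(resp.status, await resp.json())

asyncio.run(main())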
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id = str(uuid.UUID(request.match_info["id"]))

View File

@@ -1,7 +1,15 @@
import json
import uuid
from typing import Any, Literal, Optional
from typing import Any, Optional, Literal
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint
from pydantic import (
BaseModel,
ConfigDict,
Field,
conint,
field_validator,
model_validator,
)
class ListAssetsQuery(BaseModel):
@@ -148,30 +156,12 @@ class TagsRemove(TagsAdd):
pass
class ScheduleAssetScanBody(BaseModel):
roots: list[Literal["models","input","output"]] = Field(default_factory=list)
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
@field_validator("roots", mode="before")
@classmethod
def _normalize_roots(cls, v):
if v is None:
return []
if isinstance(v, str):
items = [x.strip().lower() for x in v.split(",")]
elif isinstance(v, list):
items = []
for x in v:
if isinstance(x, str):
items.extend([p.strip().lower() for p in x.split(",")])
else:
return []
out = []
seen = set()
for r in items:
if r in {"models","input","output"} and r not in seen:
out.append(r)
seen.add(r)
return out
class ScheduleAssetScanBody(BaseModel):
roots: list[RootType] = Field(..., min_length=1)
class UploadAssetSpec(BaseModel):
@@ -281,3 +271,22 @@ class UploadAssetSpec(BaseModel):
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
return self
class SetPreviewBody(BaseModel):
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
preview_id: Optional[str] = None
@field_validator("preview_id", mode="before")
@classmethod
def _norm_uuid(cls, v):
if v is None:
return None
s = str(v).strip()
if not s:
return None
try:
uuid.UUID(s)
except Exception:
raise ValueError("preview_id must be a UUID")
return s
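A short sketch of how this validator normalizes input, assuming the SetPreviewBody model above: None and blank strings both clear the preview, anything else must parse as a UUID.

from pydantic import ValidationError

print(SetPreviewBody.model_validate({"preview_id": None}).preview_id)  # None
print(SetPreviewBody.model_validate({"preview_id": " "}).preview_id)   # None
print(SetPreviewBody.model_validate(
    {"preview_id": "22222222-2222-2222-2222-222222222222"}).preview_id)
try:
    SetPreviewBody.model_validate({"preview_id": "not-a-uuid"})
except ValidationError as exc:
    print("rejected:", exc.errors()[0]["msg"])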

View File

@@ -1,12 +1,13 @@
from datetime import datetime
from typing import Any, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel):
id: str
name: str
asset_hash: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
@@ -31,7 +32,7 @@ class AssetsList(BaseModel):
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str
asset_hash: Optional[str]
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: Optional[datetime] = None
@@ -46,12 +47,12 @@ class AssetUpdated(BaseModel):
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str
asset_hash: Optional[str]
size: Optional[int] = None
mime_type: Optional[str] = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_hash: Optional[str] = None
preview_id: Optional[str] = None
created_at: Optional[datetime] = None
last_access_time: Optional[datetime] = None
@@ -95,7 +96,6 @@ class TagsRemove(BaseModel):
class AssetScanError(BaseModel):
path: str
message: str
phase: Literal["fast", "slow"]
at: Optional[str] = Field(None, description="ISO timestamp")
@@ -108,8 +108,6 @@ class AssetScanStatus(BaseModel):
finished_at: Optional[str] = None
discovered: int = 0
processed: int = 0
slow_queue_total: int = 0
slow_queue_finished: int = 0
file_errors: list[AssetScanError] = Field(default_factory=list)

View File

@@ -4,38 +4,39 @@ import mimetypes
import os
from typing import Optional, Sequence
from comfy.cli_args import args
from comfy_api.internal import async_to_sync
from .database.db import create_session
from .storage import hashing
from .database.services import (
check_fs_asset_exists_quick,
ingest_fs_asset,
touch_asset_infos_by_fs_path,
list_asset_infos_page,
update_asset_info_full,
get_asset_tags,
list_tags_with_usage,
add_tags_to_asset_info,
remove_tags_from_asset_info,
fetch_asset_info_and_asset,
touch_asset_info_by_id,
delete_asset_info_by_id,
asset_exists_by_hash,
get_asset_by_hash,
create_asset_info_for_existing_asset,
fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
list_cache_states_by_asset_hash,
asset_info_exists_for_hash,
)
from .api import schemas_in, schemas_out
from ._assets_helpers import (
get_name_and_tags_from_asset_path,
ensure_within_base,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
)
from .api import schemas_in, schemas_out
from .database.db import create_session
from .database.models import Asset
from .database.services import (
add_tags_to_asset_info,
asset_exists_by_hash,
asset_info_exists_for_asset_id,
check_fs_asset_exists_quick,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_by_hash,
get_asset_info_by_id,
get_asset_tags,
ingest_fs_asset,
list_asset_infos_page,
list_cache_states_by_asset_id,
list_tags_with_usage,
remove_tags_from_asset_info,
set_asset_info_preview,
touch_asset_info_by_id,
touch_asset_infos_by_fs_path,
update_asset_info_full,
)
from .storage import hashing
async def asset_exists(*, asset_hash: str) -> bool:
@@ -44,29 +45,21 @@ async def asset_exists(*, asset_hash: str) -> bool:
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
if not args.enable_model_processing:
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
if tags is None:
tags = []
try:
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
add_local_asset,
tags=list(dict.fromkeys([*path_tags, *tags])),
file_name=asset_name,
file_path=file_path,
)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
Notes:
- Uses absolute path as the canonical locator for the cache backend.
- Computes BLAKE3 only when the fast existence check indicates it's needed.
- This function ensures the identity row and seeds mtime in asset_cache_state.
"""
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
if not size_bytes:
@@ -132,7 +125,7 @@ async def list_assets(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
@@ -156,16 +149,17 @@ async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.As
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
preview_hash=info.preview_hash,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
@@ -176,20 +170,13 @@ async def resolve_asset_content_for_download(
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
"""
Returns (abs_path, content_type, download_name) for the given AssetInfo id.
Also touches last_access_time (only_if_newer).
Raises:
ValueError if AssetInfo cannot be found
FileNotFoundError if file for Asset cannot be found
"""
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = await list_cache_states_by_asset_hash(session, asset_hash=info.asset_hash)
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = ""
for s in states:
if s and s.file_path and os.path.isfile(s.file_path):
@@ -214,16 +201,6 @@ async def upload_asset_from_temp_path(
owner_id: str = "",
expected_asset_hash: Optional[str] = None,
) -> schemas_out.AssetCreated:
"""
Finalize an uploaded temp file:
- compute blake3 hash
- if expected_asset_hash provided, verify equality (400 on mismatch at caller)
- if an Asset with the same hash exists: discard temp, create AssetInfo only (no write)
- else resolve destination from tags and atomically move into place
- ingest into DB (assets, locator state, asset_info + tags)
Returns a populated AssetCreated payload.
"""
try:
digest = await hashing.blake3_hash(temp_path)
except Exception as e:
@ -233,7 +210,6 @@ async def upload_asset_from_temp_path(
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
# Fast path: content already known --> no writes, just create a reference
async with await create_session() as session:
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
@@ -257,43 +233,37 @@
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_hash=info.preview_hash,
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
# Resolve destination (only for truly new content)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
# Decide filename
desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name))
ensure_within_base(dest_abs, base_dir)
# Content type based on final name
content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream"
# Atomic move into place
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
# Stat final file
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
# Ingest + build response
async with await create_session() as session:
result = await ingest_fs_asset(
session,
@@ -304,7 +274,7 @@ async def upload_asset_from_temp_path(
mime_type=content_type,
info_name=os.path.basename(dest_abs),
owner_id=owner_id,
preview_hash=None,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
@@ -324,12 +294,12 @@ async def upload_asset_from_temp_path(
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_hash=info.preview_hash,
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
@@ -367,38 +337,74 @@ async def update_asset(
return schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
"""Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default),
delete the Asset row as well and remove all cached files recorded for that asset_hash.
"""
async def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: Optional[str],
owner_id: str = "",
) -> schemas_out.AssetDetail:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_hash = info_row.asset_hash if info_row else None
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
await set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
await session.commit()
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
await session.commit()
return False
if not delete_content_if_orphan or not asset_hash:
if not delete_content_if_orphan or not asset_id:
await session.commit()
return True
still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash)
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
await session.commit()
return True
states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash)
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = await get_asset_by_hash(session, asset_hash=asset_hash)
asset_row = await session.get(Asset, asset_id)
if asset_row is not None:
await session.delete(asset_row)
@@ -439,12 +445,12 @@ async def create_asset_from_hash(
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=info.asset_hash,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_hash=info.preview_hash,
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
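A toy, dependency-free illustration of the upload dedup fast path described above (sha256 stands in for blake3, dicts for the assets and assets_info tables): identical bytes never create a second content row, only a second reference.

import hashlib
import uuid

assets: dict[str, bytes] = {}      # digest -> content    (Asset)
asset_infos: dict[str, str] = {}   # info id -> digest    (AssetInfo)

def upload(data: bytes) -> tuple[str, bool]:
    digest = hashlib.sha256(data).hexdigest()
    created_new = digest not in assets
    if created_new:
        assets[digest] = data      # "move the file into place"
    info_id = str(uuid.uuid4())
    asset_infos[info_id] = digest  # a reference is always created
    return info_id, created_new

print(upload(b"model-bytes")[1])   # True  -> 201 in the route above
print(upload(b"model-bytes")[1])   # False -> 200, content deduplicated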

View File

@@ -1,52 +1,55 @@
import asyncio
import contextlib
import logging
import os
import uuid
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Callable, Literal, Optional, Sequence
from typing import Literal, Optional
import sqlalchemy as sa
import folder_paths
from . import assets_manager
from .api import schemas_out
from ._assets_helpers import get_comfy_models_folders
from ._assets_helpers import (
collect_models_files,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
list_tree,
new_scan_id,
prefixes_for_root,
ts_to_iso,
)
from .api import schemas_in, schemas_out
from .database.db import create_session
from .database.helpers import (
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
)
from .database.models import Asset, AssetCacheState, AssetInfo
from .database.services import (
check_fs_asset_exists_quick,
compute_hash_and_dedup_for_cache_state,
ensure_seed_for_path,
list_cache_states_by_asset_id,
list_cache_states_with_asset_under_prefixes,
add_missing_tag_for_asset_hash,
remove_missing_tag_for_asset_hash,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
)
LOGGER = logging.getLogger(__name__)
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
SLOW_HASH_CONCURRENCY = 1
@dataclass
class ScanProgress:
scan_id: str
root: RootType
root: schemas_in.RootType
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
scheduled_at: float = field(default_factory=lambda: time.time())
started_at: Optional[float] = None
finished_at: Optional[float] = None
discovered: int = 0
processed: int = 0
slow_queue_total: int = 0
slow_queue_finished: int = 0
file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"}
# Internal diagnostics for logs
_fast_total_seen: int = 0
_fast_clean: int = 0
file_errors: list[dict] = field(default_factory=list)
@dataclass
@@ -56,18 +59,14 @@ class SlowQueueState:
closed: bool = False
RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {}
async def start_background_assets_scan():
await fast_reconcile_and_kickoff(progress_cb=_console_cb)
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
def current_statuses() -> schemas_out.AssetScanStatusResponse:
scans = []
for root in ALLOWED_ROOTS:
for root in schemas_in.ALLOWED_ROOTS:
prog = PROGRESS_BY_ROOT.get(root)
if not prog:
continue
@@ -75,83 +74,65 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
return schemas_out.AssetScanStatusResponse(scans=scans)
async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
"""Schedule scans for the provided roots; returns progress snapshots.
Rules:
- Only roots in {models, input, output} are accepted.
- If a root is already scanning, we do NOT enqueue another one. Status returned as-is.
- Otherwise a new task is created and started immediately.
- Files with zero size are skipped.
"""
normalized: list[RootType] = []
seen = set()
for r in roots or []:
rr = r.strip().lower()
if rr in ALLOWED_ROOTS and rr not in seen:
normalized.append(rr) # type: ignore
seen.add(rr)
if not normalized:
normalized = list(ALLOWED_ROOTS) # schedule all by default
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
results: list[ScanProgress] = []
for root in normalized:
for root in roots:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
results.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue())
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
RUNNING_TASKS[root] = asyncio.create_task(
_pipeline_for_root(root, prog, progress_cb=None),
_run_hash_verify_pipeline(root, prog, state),
name=f"asset-scan:{root}",
)
results.append(prog)
return _status_response_for(results)
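A hypothetical call-site sketch for the scheduler above (it relies on the module's own schedule_scans and current_statuses): request scans for two roots, then poll the snapshots; a root already being scanned is returned as-is instead of being queued twice.

async def _demo() -> None:
    resp = await schedule_scans(["models", "input"])
    print([(s.root, s.status) for s in resp.scans])
    # later, e.g. from a status endpoint handler:
    print(current_statuses().model_dump(mode="json"))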
async def fast_reconcile_and_kickoff(
roots: Optional[Sequence[str]] = None,
*,
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None,
) -> schemas_out.AssetScanStatusResponse:
"""
Startup helper: do the fast pass now (so we know queue size),
start slow hashing in the background, return immediately.
"""
normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS]
snaps: list[ScanProgress] = []
for root in normalized:
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
snaps.append(PROGRESS_BY_ROOT[root])
continue
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
PROGRESS_BY_ROOT[root] = prog
state = SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
prog.status = "running"
prog.started_at = time.time()
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
for r in roots:
try:
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
except Exception as e:
_append_error(prog, phase="fast", path="", message=str(e))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Fast reconcile failed for %s", root)
snaps.append(prog)
await _fast_db_consistency_pass(r)
except Exception as ex:
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
for p in paths:
try:
st = os.stat(p, follow_symlinks=True)
if not int(st.st_size or 0):
continue
size_bytes = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
name, tags = get_name_and_tags_from_asset_path(p)
await _seed_one_async(p, size_bytes, mtime_ns, name, tags)
except OSError:
continue
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
RUNNING_TASKS[root] = asyncio.create_task(
_await_workers_then_finish(root, prog, state, progress_cb=progress_cb),
name=f"asset-hash:{root}",
async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None:
async with await create_session() as sess:
await ensure_seed_for_path(
sess,
abs_path=p,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
info_name=name,
tags=tags,
owner_id="",
)
snaps.append(prog)
return _status_response_for(snaps)
await sess.commit()
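A toy illustration of what seeding records for a single file under the flow above: an Asset with hash=NULL (identity not computed yet), an AssetInfo named after the file, and a cache-state row carrying path and mtime_ns so the later hash/verify pipeline can tell whether the file changed. Dicts stand in for ORM rows; field names mirror the migration.

import os
import uuid

def seed_rows(abs_path: str) -> dict:
    st = os.stat(abs_path)
    asset_id = str(uuid.uuid4())
    return {
        "asset": {"id": asset_id, "hash": None, "size_bytes": st.st_size},
        "asset_info": {"id": str(uuid.uuid4()), "asset_id": asset_id,
                       "name": os.path.basename(abs_path)},
        "cache_state": {"asset_id": asset_id, "file_path": abs_path,
                        "mtime_ns": st.st_mtime_ns, "needs_verify": False},
    }

print(seed_rows(__file__))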
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
@@ -163,18 +144,15 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A
scan_id=progress.scan_id,
root=progress.root,
status=progress.status,
scheduled_at=_ts_to_iso(progress.scheduled_at),
started_at=_ts_to_iso(progress.started_at),
finished_at=_ts_to_iso(progress.finished_at),
scheduled_at=ts_to_iso(progress.scheduled_at),
started_at=ts_to_iso(progress.started_at),
finished_at=ts_to_iso(progress.finished_at),
discovered=progress.discovered,
processed=progress.processed,
slow_queue_total=progress.slow_queue_total,
slow_queue_finished=progress.slow_queue_finished,
file_errors=[
schemas_out.AssetScanError(
path=e.get("path", ""),
message=e.get("message", ""),
phase=e.get("phase", "slow"),
at=e.get("at"),
)
for e in (progress.file_errors or [])
@@ -182,27 +160,100 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A
)
async def _pipeline_for_root(
root: RootType,
prog: ScanProgress,
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
) -> None:
state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue())
SLOW_STATE_BY_ROOT[root] = state
async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
"""Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime."""
prefixes = prefixes_for_root(root)
if not prefixes:
return
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
conds.append(AssetCacheState.file_path.like(base + "%"))
async with await create_session() as sess:
rows = (
await sess.execute(
sa.select(
AssetCacheState.id,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
Asset.hash,
AssetCacheState.file_path,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
)
).all()
to_set = []
to_clear = []
for sid, mtime_db, needs_verify, a_hash, fp in rows:
try:
st = os.stat(fp, follow_symlinks=True)
except OSError:
# Missing files are handled by missing-tag reconciliation later.
continue
actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
if a_hash is not None:
if mtime_db is None or int(mtime_db) != int(actual_mtime_ns):
if not needs_verify:
to_set.append(sid)
else:
if needs_verify:
to_clear.append(sid)
if to_set:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set))
.values(needs_verify=True)
)
if to_clear:
await sess.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear))
.values(needs_verify=False)
)
await sess.commit()
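The per-row decision above, restated as a small standalone function (a sketch, not the committed code): only hashed assets participate, and a stale or absent stored mtime_ns flags the row for re-verification.

import os
from typing import Optional

def wants_verify(file_path: str, mtime_db: Optional[int], hashed: bool) -> Optional[bool]:
    if not hashed:
        return None  # seed assets go through the unhashed queue instead
    try:
        st = os.stat(file_path, follow_symlinks=True)
    except OSError:
        return None  # missing files are left to the missing-tag pass
    actual = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
    return mtime_db is None or int(mtime_db) != int(actual)

print(wants_verify(__file__, None, hashed=True))  # True: no recorded mtime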
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
prog.status = "running"
prog.started_at = time.time()
try:
await _reconcile_missing_tags_for_root(root, prog)
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb)
prefixes = prefixes_for_root(root)
await _refresh_verify_flags_for_root(root, prog)
# collect candidates from DB
async with await create_session() as sess:
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
# dedupe: prioritize verification first
seen = set()
ordered: list[int] = []
for lst in (verify_ids, unhashed_ids):
for sid in lst:
if sid not in seen:
seen.add(sid)
ordered.append(sid)
prog.discovered = len(ordered)
# queue up work
for sid in ordered:
await state.queue.put(sid)
state.closed = True
_start_state_workers(root, prog, state)
await _await_state_workers_then_finish(root, prog, state)
except asyncio.CancelledError:
prog.status = "cancelled"
raise
except Exception as exc:
_append_error(prog, phase="slow", path="", message=str(exc))
_append_error(prog, path="", message=str(exc))
prog.status = "failed"
prog.finished_at = time.time()
LOGGER.exception("Asset scan failed for %s", root)
@@ -210,110 +261,13 @@ async def _pipeline_for_root(
RUNNING_TASKS.pop(root, None)
async def _fast_reconcile_into_queue(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
) -> None:
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
"""
Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files,
and queue the rest for slow hashing.
"""
if root == "models":
files = _collect_models_files()
preset_discovered = _count_nonzero_in_list(files)
files_iter = asyncio.Queue()
for p in files:
await files_iter.put(p)
await files_iter.put(None) # sentinel for our local draining loop
elif root == "input":
base = folder_paths.get_input_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
files_iter = await _queue_tree_files(base)
elif root == "output":
base = folder_paths.get_output_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
files_iter = await _queue_tree_files(base)
else:
raise RuntimeError(f"Unsupported root: {root}")
Detect missing files quickly and toggle 'missing' tag per asset_id.
prog.discovered = int(preset_discovered or 0)
queued = 0
checked = 0
clean = 0
async with await create_session() as sess:
while True:
item = await files_iter.get()
files_iter.task_done()
if item is None:
break
abs_path = item
checked += 1
# Stat; skip empty/unreadable
try:
st = os.stat(abs_path, follow_symlinks=True)
if not st.st_size:
continue
size_bytes = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
except OSError as e:
_append_error(prog, phase="fast", path=abs_path, message=str(e))
continue
try:
known = await check_fs_asset_exists_quick(
sess,
file_path=abs_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
)
except Exception as e:
_append_error(prog, phase="fast", path=abs_path, message=str(e))
known = False
if known:
clean += 1
prog.processed += 1
else:
await state.queue.put(abs_path)
queued += 1
prog.slow_queue_total += 1
if progress_cb:
progress_cb(root, "fast", prog.processed, False, {
"checked": checked,
"clean": clean,
"queued": queued,
"discovered": prog.discovered,
})
prog._fast_total_seen = checked
prog._fast_clean = clean
if progress_cb:
progress_cb(root, "fast", prog.processed, True, {
"checked": checked,
"clean": clean,
"queued": queued,
"discovered": prog.discovered,
})
state.closed = True
async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None:
"""
Logic for detecting missing Asset files:
- Clear 'missing' only if at least one cached path passes fast check:
exists AND mtime_ns matches AND size matches.
- Otherwise set 'missing'.
Files that exist but fail fast check will be slow-hashed by the normal pipeline,
and ingest_fs_asset will clear 'missing' if they truly match.
Rules:
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
- We consider ALL cache states of the asset (across roots) before tagging.
"""
if root == "models":
bases: list[str] = []
@@ -326,232 +280,217 @@ async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -
try:
async with await create_session() as sess:
# state + hash + size for the current root
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
by_hash: dict[str, dict[str, bool]] = {} # {hash: {"any_fast_ok": bool}}
for state, size_db in rows:
h = state.asset_hash
acc = by_hash.get(h)
# Track fast_ok within the scanned root and whether the asset is hashed
by_asset: dict[str, dict[str, bool]] = {}
for state, a_hash, size_db in rows:
aid = state.asset_id
acc = by_asset.get(aid)
if acc is None:
acc = {"any_fast_ok": False}
by_hash[h] = acc
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
by_asset[aid] = acc
try:
st = os.stat(state.file_path, follow_symlinks=True)
actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
fast_ok = False
if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
if int(size_db) > 0 and int(st.st_size) == int(size_db):
fast_ok = True
if acc["hashed"]:
if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]):
fast_ok = True
if fast_ok:
acc["any_fast_ok"] = True
acc["any_fast_ok_here"] = True
except FileNotFoundError:
pass # not fast_ok
pass
except OSError as e:
_append_error(prog, phase="fast", path=state.file_path, message=str(e))
_append_error(prog, path=state.file_path, message=str(e))
for h, acc in by_hash.items():
# Decide per asset, considering ALL its states (not just this root)
for aid, acc in by_asset.items():
try:
if acc["any_fast_ok"]:
await remove_missing_tag_for_asset_hash(sess, asset_hash=h)
if not acc["hashed"]:
# Never tag seed assets as missing
continue
any_fast_ok_global = acc["any_fast_ok_here"]
if not any_fast_ok_global:
# Check other states outside this root
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
for st in others:
try:
s = os.stat(st.file_path, follow_symlinks=True)
actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]:
any_fast_ok_global = True
break
except OSError:
continue
if any_fast_ok_global:
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
else:
await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic")
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception as ex:
_append_error(prog, phase="fast", path="", message=f"reconcile {h[:18]}: {ex}")
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
await sess.commit()
except Exception as e:
_append_error(prog, phase="fast", path="", message=f"reconcile failed: {e}")
_append_error(prog, path="", message=f"reconcile failed: {e}")
def _start_slow_workers(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
) -> None:
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
if state.workers:
return
async def _worker(_worker_id: int):
async def _worker(_wid: int):
while True:
item = await state.queue.get()
sid = await state.queue.get()
try:
if item is None:
if sid is None:
return
try:
await asyncio.to_thread(assets_manager.populate_db_with_asset, item)
except Exception as e:
_append_error(prog, phase="slow", path=item, message=str(e))
async with await create_session() as sess:
# Optional: fetch path for better error messages
st = await sess.get(AssetCacheState, sid)
try:
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
await sess.commit()
except Exception as e:
path = st.file_path if st else f"state:{sid}"
_append_error(prog, path=path, message=str(e))
raise
except Exception:
pass
finally:
# Slow queue finished for this item; also counts toward overall processed
prog.slow_queue_finished += 1
prog.processed += 1
if progress_cb:
progress_cb(root, "slow", prog.processed, False, {
"slow_queue_finished": prog.slow_queue_finished,
"slow_queue_total": prog.slow_queue_total,
})
finally:
state.queue.task_done()
state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)]
state.workers = [
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
for i in range(SLOW_HASH_CONCURRENCY)
]
async def _close_when_empty():
# When the fast phase closed the queue, push sentinels to end workers
async def _close_when_ready():
while not state.closed:
await asyncio.sleep(0.05)
for _ in range(SLOW_HASH_CONCURRENCY):
await state.queue.put(None)
asyncio.create_task(_close_when_empty())
asyncio.create_task(_close_when_ready())
async def _await_workers_then_finish(
root: RootType,
prog: ScanProgress,
state: SlowQueueState,
*,
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
) -> None:
async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
if state.workers:
await asyncio.gather(*state.workers, return_exceptions=True)
await _reconcile_missing_tags_for_root(root, prog)
prog.finished_at = time.time()
prog.status = "completed"
if progress_cb:
progress_cb(root, "slow", prog.processed, True, {
"slow_queue_finished": prog.slow_queue_finished,
"slow_queue_total": prog.slow_queue_total,
})
def _collect_models_files() -> list[str]:
"""Collect absolute file paths from configured model buckets under models_dir."""
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
# ensure within allowed bases
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int:
if not os.path.isdir(base_abs):
return 0
total = 0
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
if not only_nonzero:
total += len(filenames)
else:
for name in filenames:
with contextlib.suppress(OSError):
st = os.stat(os.path.join(dirpath, name), follow_symlinks=True)
if st.st_size:
total += 1
return total
def _count_nonzero_in_list(paths: list[str]) -> int:
cnt = 0
for p in paths:
with contextlib.suppress(OSError):
st = os.stat(p, follow_symlinks=True)
if st.st_size:
cnt += 1
return cnt
async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
"""
Walk base_dir in a worker thread and return a queue prefilled with all paths,
terminated by a single None sentinel for the draining loop in fast reconcile.
"""
q: asyncio.Queue = asyncio.Queue()
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
await q.put(None)
return q
def _walk_list():
paths: list[str] = []
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
paths.append(os.path.abspath(os.path.join(dirpath, name)))
return paths
for p in await asyncio.to_thread(_walk_list):
await q.put(p)
await q.put(None)
return q
def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None:
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
prog.file_errors.append({
"path": path,
"message": message,
"phase": phase,
"at": _ts_to_iso(time.time()),
"at": ts_to_iso(time.time()),
})
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
# interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models)
try:
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
except Exception:
return None
async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None:
"""
Quick pass over asset_cache_state for `root`:
- If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos.
- If file missing and Asset.hash is NOT NULL:
* If at least one state for this Asset is fast-ok, delete the missing state.
* If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset.
- If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag.
"""
prefixes = prefixes_for_root(root)
if not prefixes:
return
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
conds.append(AssetCacheState.file_path.like(base + "%"))
def _new_scan_id(root: RootType) -> str:
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
async with await create_session() as sess:
if not conds:
return
def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: dict):
if phase == "fast":
if finished:
logging.info(
"[assets][%s] fast done: processed=%s/%s queued=%s",
root,
total_processed,
e["discovered"],
e["queued"],
)
elif e.get("checked", 0) % 1000 == 0: # do not spam with fast progress
logging.info(
"[assets][%s] fast progress: processed=%s/%s",
root,
total_processed,
e["discovered"],
)
elif phase == "slow":
if finished:
if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0):
logging.info(
"[assets][%s] slow done: %s/%s",
root,
e.get("slow_queue_finished", 0),
e.get("slow_queue_total", 0),
)
elif e.get('slow_queue_finished', 0) % 3 == 0:
logging.info(
"[assets][%s] slow progress: %s/%s",
root,
e.get("slow_queue_finished", 0),
e.get("slow_queue_total", 0),
rows = (
await sess.execute(
sa.select(AssetCacheState, Asset.hash, Asset.size_bytes)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
# Group by asset_id with status per state
by_asset: dict[str, dict] = {}
for st, a_hash, a_size in rows:
aid = st.asset_id
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
exists = False
fast_ok = False
try:
s = os.stat(st.file_path, follow_symlinks=True)
exists = True
actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]:
fast_ok = True
except FileNotFoundError:
exists = False
except OSError as ex:
exists = False
LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex)
acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok})
# Apply actions
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
missing_states = [s["obj"] for s in states if not s["exists"]]
if a_hash is None:
# Seed asset: if all states gone (and in practice there is only one), remove the whole Asset
if states and all_missing:
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = await sess.get(Asset, aid)
if asset:
await sess.delete(asset)
# else leave it for the slow scan to verify/rehash
else:
if any_fast_ok:
# Remove 'missing' and delete just the stale state rows
for st in missing_states:
try:
await sess.delete(await sess.get(AssetCacheState, st.id))
except Exception:
pass
try:
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
except Exception:
pass
else:
# No fast-ok path: mark as missing
try:
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
except Exception:
pass
await sess.flush()
await sess.commit()

View File

@@ -1,186 +0,0 @@
from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, exists
from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta
from .._assets_helpers import normalize_tags
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = normalize_tags(list(names))
if not wanted:
return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
by_name = {t.name: t for t in existing}
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
if to_create:
session.add_all(to_create)
await session.flush()
by_name.update({t.name: t for t in to_create})
return [by_name[n] for n in wanted]
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Optional[Sequence[str]],
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
) -> sa.sql.Select:
"""Apply metadata filters using the projection table asset_info_meta.
Semantics:
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
- For None: key is missing OR key has explicit null (val_json IS NULL).
- For list values: ANY-of the list elements matches (EXISTS for any).
(Change to ALL-of by 'for each element: stmt = stmt.where(_exists_clause_for_value(key, elem))')
"""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
# Missing OR null:
if value is None:
# either: no row for key OR a row for key with explicit null
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
# Typed scalar matches:
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
# store as Decimal for equality against NUMERIC(38,10)
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
# Complex: compare JSON (no index, but supported)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
# ANY-of (exists for any element)
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def is_scalar(v: Any) -> bool:
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value: Any) -> list[dict]:
"""
Turn a metadata key/value into one or more projection rows:
- scalar -> one row (ordinal=0) in the proper typed column
- list of scalars -> one row per element with ordinal=i
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
- None -> single row with all value columns NULL
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
# store numeric; SQLAlchemy will coerce to Numeric
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
# Fallback to json
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# list contains objects -> one val_json per element
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
# Dict or any other structure -> single json row
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
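
The docstring above suggests an ALL-of variant for list values; as a sketch, the list branch of the loop inside apply_metadata_filter would become the following (not the shipped behavior, which is ANY-of):

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            for elem in v:  # every element must have a matching projection row
                stmt = stmt.where(_exists_clause_for_value(k, elem))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))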

View File

@ -4,14 +4,20 @@ import shutil
from contextlib import asynccontextmanager
from typing import Optional
from comfy.cli_args import args
from alembic import command
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine, text
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from comfy.cli_args import args
LOGGER = logging.getLogger(__name__)
ENGINE: Optional[AsyncEngine] = None

View File

@ -0,0 +1,23 @@
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
add_missing_tag_for_asset_hash,
add_missing_tag_for_asset_id,
ensure_tags_exist,
remove_missing_tag_for_asset_hash,
remove_missing_tag_for_asset_id,
)
__all__ = [
"apply_tag_filters",
"apply_metadata_filter",
"is_scalar",
"project_kv",
"ensure_tags_exist",
"add_missing_tag_for_asset_id",
"add_missing_tag_for_asset_hash",
"remove_missing_tag_for_asset_id",
"remove_missing_tag_for_asset_hash",
"visible_owner_clause",
]

View File

@ -0,0 +1,87 @@
from decimal import Decimal
from typing import Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from ..._assets_helpers import normalize_tags
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Optional[Sequence[str]],
exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
        if isinstance(value, (int, float, Decimal)):
            num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
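
Taken together, the two helpers are meant to be chained onto a select against AssetInfo. A minimal usage sketch, assuming a live AsyncSession and the imports above (the tag names and metadata values are illustrative only):

    async def find_visible_checkpoints(session):
        stmt = sa.select(AssetInfo)
        # every included tag must be linked; no excluded tag may be linked
        stmt = apply_tag_filters(stmt, include_tags=["models", "checkpoints"], exclude_tags=["missing"])
        # scalar -> typed EXISTS; list -> ANY-of its elements
        stmt = apply_metadata_filter(stmt, {"epoch": 10, "variant": ["fp8", "fp16"]})
        return (await session.execute(stmt)).scalars().all()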

View File

@ -0,0 +1,12 @@
import sqlalchemy as sa
from ..models import AssetInfo
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
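
In effect (illustrative, with the AssetInfo model imported above):

    # visible_owner_clause("")      -> owner_id == ''           : only shared, owner-less rows
    # visible_owner_clause("alice") -> owner_id IN ('', 'alice'): shared rows plus alice's own
    stmt = sa.select(AssetInfo).where(visible_owner_clause("alice"))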

View File

@ -0,0 +1,64 @@
from decimal import Decimal
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value):
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
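
For reference, the rows the function emits for a few representative inputs (derived from the code above; None-valued columns omitted for brevity):

    # project_kv("epochs", 3)           -> [{"key": "epochs", "ordinal": 0, "val_num": Decimal("3")}]
    # project_kv("nsfw", False)         -> [{"key": "nsfw", "ordinal": 0, "val_bool": False}]
    # project_kv("aliases", ["a", "b"]) -> [{"key": "aliases", "ordinal": 0, "val_str": "a"},
    #                                       {"key": "aliases", "ordinal": 1, "val_str": "b"}]
    # project_kv("cfg", {"steps": 20})  -> [{"key": "cfg", "ordinal": 0, "val_json": {"steps": 20}}]
    # project_kv("note", None)          -> one row with all four value columns None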

View File

@ -0,0 +1,102 @@
from typing import Iterable
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from ..._assets_helpers import normalize_tags
from ..models import Asset, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
wanted = normalize_tags(list(names))
if not wanted:
return []
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
by_name = {t.name: t for t in existing}
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
if to_create:
session.add_all(to_create)
await session.flush()
by_name.update({t.name: t for t in to_create})
return [by_name[n] for n in wanted]
async def add_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
origin: str = "automatic",
) -> int:
"""Ensure every AssetInfo for asset_id has 'missing' tag."""
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
if not ids:
return 0
existing = {
asset_info_id
for (asset_info_id,) in (
await session.execute(
select(AssetInfoTag.asset_info_id).where(
AssetInfoTag.asset_info_id.in_(ids),
AssetInfoTag.tag_name == "missing",
)
)
).all()
}
to_add = [i for i in ids if i not in existing]
if not to_add:
return 0
now = utcnow()
session.add_all(
[
AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
for i in to_add
]
)
await session.flush()
return len(to_add)
async def add_missing_tag_for_asset_hash(
session: AsyncSession,
*,
asset_hash: str,
origin: str = "automatic",
) -> int:
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
if not asset:
return 0
return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin)
async def remove_missing_tag_for_asset_id(
session: AsyncSession,
*,
asset_id: str,
) -> int:
"""Remove the 'missing' tag from all AssetInfos for asset_id."""
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
if not ids:
return 0
res = await session.execute(
delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(ids),
AssetInfoTag.tag_name == "missing",
)
)
await session.flush()
return int(res.rowcount or 0)
async def remove_missing_tag_for_asset_hash(
session: AsyncSession,
*,
asset_hash: str,
) -> int:
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
if not asset:
return 0
return await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
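
A sketch of the intended call pattern from a scanner, assuming an async_sessionmaker (the name `sessionmaker` here is hypothetical glue; the helper is the one defined above):

    async def flag_vanished_file(sessionmaker, asset_id: str) -> int:
        # tags every AssetInfo of the asset as 'missing'; returns the number of links added
        async with sessionmaker() as session, session.begin():
            return await add_missing_tag_for_asset_id(session, asset_id=asset_id)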

View File

@ -1,27 +1,26 @@
import uuid
from datetime import datetime
from typing import Any, Optional
import uuid
from sqlalchemy import (
Integer,
JSON,
BigInteger,
Boolean,
CheckConstraint,
DateTime,
ForeignKey,
Index,
UniqueConstraint,
JSON,
Integer,
Numeric,
String,
Text,
CheckConstraint,
Numeric,
Boolean,
UniqueConstraint,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
from .timeutil import utcnow
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
@ -46,7 +45,8 @@ def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
class Asset(Base):
__tablename__ = "assets"
hash: Mapped[str] = mapped_column(String(256), primary_key=True)
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
@ -56,8 +56,8 @@ class Asset(Base):
infos: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="asset",
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
foreign_keys=lambda: [AssetInfo.asset_hash],
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
@ -65,8 +65,8 @@ class Asset(Base):
preview_of: Mapped[list["AssetInfo"]] = relationship(
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
foreign_keys=lambda: [AssetInfo.preview_hash],
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
@ -76,36 +76,32 @@ class Asset(Base):
passive_deletes=True,
)
locations: Mapped[list["AssetLocation"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset hash={self.hash[:12]}>"
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_hash", "asset_hash"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
@ -114,27 +110,7 @@ class AssetCacheState(Base):
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} hash={self.asset_hash[:12]} path={self.file_path!r}>"
class AssetLocation(Base):
__tablename__ = "asset_locations"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs"
locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object"
expected_size_bytes: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
etag: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
last_modified: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
asset: Mapped["Asset"] = relationship(back_populates="locations")
__table_args__ = (
UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
Index("ix_asset_locations_hash", "asset_hash"),
Index("ix_asset_locations_provider", "provider"),
)
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
@ -143,31 +119,23 @@ class AssetInfo(Base):
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
)
preview_hash: Mapped[Optional[str]] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
# Relationships
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
foreign_keys=[asset_hash],
foreign_keys=[asset_id],
lazy="selectin",
)
preview_asset: Mapped[Optional[Asset]] = relationship(
"Asset",
back_populates="preview_of",
foreign_keys=[preview_hash],
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
@ -186,16 +154,16 @@ class AssetInfo(Base):
tags: Mapped[list["Tag"]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
lazy="joined",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
)
__table_args__ = (
UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"),
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_hash", "asset_hash"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
@ -207,7 +175,7 @@ class AssetInfo(Base):
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
class AssetInfoMeta(Base):

File diff suppressed because it is too large

View File

@ -0,0 +1,56 @@
from .content import (
check_fs_asset_exists_quick,
compute_hash_and_dedup_for_cache_state,
ensure_seed_for_path,
ingest_fs_asset,
list_cache_states_with_asset_under_prefixes,
list_unhashed_candidates_under_prefixes,
list_verify_candidates_under_prefixes,
redirect_all_references_then_delete_asset,
touch_asset_infos_by_fs_path,
)
from .info import (
add_tags_to_asset_info,
create_asset_info_for_existing_asset,
delete_asset_info_by_id,
fetch_asset_info_and_asset,
fetch_asset_info_asset_and_tags,
get_asset_tags,
list_asset_infos_page,
list_tags_with_usage,
remove_tags_from_asset_info,
replace_asset_info_metadata_projection,
set_asset_info_preview,
set_asset_info_tags,
touch_asset_info_by_id,
update_asset_info_full,
)
from .queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
get_cache_state_by_asset_id,
list_cache_states_by_asset_id,
)
__all__ = [
# queries
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
"get_cache_state_by_asset_id",
"list_cache_states_by_asset_id",
# info
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
"update_asset_info_full", "replace_asset_info_metadata_projection",
"touch_asset_info_by_id", "delete_asset_info_by_id",
"add_tags_to_asset_info", "remove_tags_from_asset_info",
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
# content
"check_fs_asset_exists_quick", "ensure_seed_for_path",
"redirect_all_references_then_delete_asset",
"compute_hash_and_dedup_for_cache_state",
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
"list_cache_states_with_asset_under_prefixes",
]

View File

@ -0,0 +1,746 @@
import logging
import os
from datetime import datetime
from typing import Any, Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload
from ..._assets_helpers import compute_model_relative_filename, normalize_tags
from ...storage import hashing as hashing_mod
from ..helpers import (
ensure_tags_exist,
remove_missing_tag_for_asset_id,
)
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow
from .info import replace_asset_info_metadata_projection
async def check_fs_asset_exists_quick(
session: AsyncSession,
*,
file_path: str,
size_bytes: Optional[int] = None,
mtime_ns: Optional[int] = None,
) -> bool:
"""Return True if a cache row exists for this absolute path and (optionally) mtime/size match."""
locator = os.path.abspath(file_path)
stmt = (
sa.select(sa.literal(True))
.select_from(AssetCacheState)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(AssetCacheState.file_path == locator)
.limit(1)
)
conds = []
if mtime_ns is not None:
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
if conds:
stmt = stmt.where(*conds)
row = (await session.execute(stmt)).first()
return row is not None
async def ensure_seed_for_path(
session: AsyncSession,
*,
abs_path: str,
size_bytes: int,
mtime_ns: int,
info_name: str,
tags: Sequence[str],
owner_id: str = "",
) -> str:
"""Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. Returns asset_id."""
locator = os.path.abspath(abs_path)
now = utcnow()
state = (
await session.execute(
sa.select(AssetCacheState, Asset)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(AssetCacheState.file_path == locator)
.limit(1)
)
).first()
if state:
state_row: AssetCacheState = state[0]
asset_row: Asset = state[1]
changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != int(mtime_ns)
if changed:
state_row.mtime_ns = int(mtime_ns)
state_row.needs_verify = True
if asset_row.size_bytes == 0 and size_bytes > 0:
asset_row.size_bytes = int(size_bytes)
return asset_row.id
# Create new asset (hash=NULL)
asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now)
session.add(asset)
await session.flush() # to get id
cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=int(mtime_ns), needs_verify=False)
session.add(cs)
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
await session.flush()
# Attach tags
want = normalize_tags(tags)
if want:
await ensure_tags_exist(session, want, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now)
for t in want
])
await session.flush()
return asset.id
async def redirect_all_references_then_delete_asset(
session: AsyncSession,
*,
duplicate_asset_id: str,
canonical_asset_id: str,
) -> None:
"""
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
- Otherwise, simply repoint the AssetInfo.asset_id.
- Always retarget AssetCacheState rows.
- Finally delete the duplicate Asset row.
"""
if duplicate_asset_id == canonical_asset_id:
return
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
dup_infos = (
await session.execute(
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
)
).unique().scalars().all()
for info in dup_infos:
# Try to find an existing collision on canonical
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == canonical_asset_id,
AssetInfo.owner_id == info.owner_id,
AssetInfo.name == info.name,
)
.limit(1)
)
).unique().scalars().first()
if existing:
# Merge metadata (prefer existing keys, fill gaps from duplicate)
merged_meta = dict(existing.user_metadata or {})
other_meta = info.user_metadata or {}
for k, v in other_meta.items():
if k not in merged_meta:
merged_meta[k] = v
if merged_meta != (existing.user_metadata or {}):
await replace_asset_info_metadata_projection(
session,
asset_info_id=existing.id,
user_metadata=merged_meta,
)
# Merge tags (union)
existing_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
)
).all()
}
from_tags = {
t for (t,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
)
).all()
}
to_add = sorted(from_tags - existing_tags)
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
now = utcnow()
session.add_all([
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
for t in to_add
])
await session.flush()
# Merge preview and times
if existing.preview_id is None and info.preview_id is not None:
existing.preview_id = info.preview_id
if info.last_access_time and (
existing.last_access_time is None or info.last_access_time > existing.last_access_time
):
existing.last_access_time = info.last_access_time
existing.updated_at = utcnow()
await session.flush()
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
await session.delete(info)
await session.flush()
else:
# Simple retarget
info.asset_id = canonical_asset_id
info.updated_at = utcnow()
await session.flush()
# 2) Repoint cache states and previews
await session.execute(
sa.update(AssetCacheState)
.where(AssetCacheState.asset_id == duplicate_asset_id)
.values(asset_id=canonical_asset_id)
)
await session.execute(
sa.update(AssetInfo)
.where(AssetInfo.preview_id == duplicate_asset_id)
.values(preview_id=canonical_asset_id)
)
# 3) Remove duplicate Asset
dup = await session.get(Asset, duplicate_asset_id)
if dup:
await session.delete(dup)
await session.flush()
async def compute_hash_and_dedup_for_cache_state(
session: AsyncSession,
*,
state_id: int,
) -> Optional[str]:
"""
Compute hash for the given cache state, deduplicate, and settle verify cases.
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
"""
state = await session.get(AssetCacheState, state_id)
if not state:
return None
path = state.file_path
try:
if not os.path.isfile(path):
# File vanished: drop the state. If the Asset was a seed (hash NULL)
# and has no other states, drop the Asset too.
asset = await session.get(Asset, state.asset_id)
await session.delete(state)
await session.flush()
if asset and asset.hash is None:
remaining = (
await session.execute(
sa.select(sa.func.count())
.select_from(AssetCacheState)
.where(AssetCacheState.asset_id == asset.id)
)
).scalar_one()
if int(remaining or 0) == 0:
await session.delete(asset)
await session.flush()
return None
digest = await hashing_mod.blake3_hash(path)
new_hash = f"blake3:{digest}"
st = os.stat(path, follow_symlinks=True)
new_size = int(st.st_size)
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
# Current asset of this state
this_asset = await session.get(Asset, state.asset_id)
# If the state got orphaned somehow (race), just reattach appropriately.
if not this_asset:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
state.asset_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
state.asset_id = new_asset.id
state.mtime_ns = mtime_ns
state.needs_verify = False
try:
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
except Exception:
pass
await session.flush()
return state.asset_id
# 1) Seed asset case (hash is NULL): claim or merge into canonical
if this_asset.hash is None:
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
# Merge seed asset into canonical (safe, collision-aware)
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
# Refresh state after the merge
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
try:
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
except Exception:
pass
await session.flush()
return canonical.id
# No canonical: try to claim the hash; handle races with a SAVEPOINT
try:
async with session.begin_nested():
this_asset.hash = new_hash
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
await session.flush()
except IntegrityError:
# Someone else claimed it concurrently; fetch canonical and merge
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical and canonical.id != this_asset.id:
await redirect_all_references_then_delete_asset(
session,
duplicate_asset_id=this_asset.id,
canonical_asset_id=canonical.id,
)
state = await session.get(AssetCacheState, state_id)
if state:
state.mtime_ns = mtime_ns
state.needs_verify = False
try:
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
except Exception:
pass
await session.flush()
return canonical.id
# If we got here, the integrity error was not about hash uniqueness
raise
# Claimed successfully
state.mtime_ns = mtime_ns
state.needs_verify = False
try:
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
except Exception:
pass
await session.flush()
return this_asset.id
# 2) Verify case for hashed assets
if this_asset.hash == new_hash:
# Content unchanged; tidy up sizes/mtime
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
this_asset.size_bytes = new_size
state.mtime_ns = mtime_ns
state.needs_verify = False
try:
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
except Exception:
pass
await session.flush()
return this_asset.id
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
canonical = (
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
).scalars().first()
if canonical:
target_id = canonical.id
else:
now = utcnow()
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
session.add(new_asset)
await session.flush()
target_id = new_asset.id
state.asset_id = target_id
state.mtime_ns = mtime_ns
state.needs_verify = False
try:
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
except Exception:
pass
await session.flush()
return target_id
except Exception:
# Propagate; caller records the error and continues the worker.
raise
async def list_unhashed_candidates_under_prefixes(
session: AsyncSession, *, prefixes: Sequence[str]
) -> list[int]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
conds.append(AssetCacheState.file_path.like(base + "%"))
    rows = (
        await session.execute(
            sa.select(AssetCacheState.id, AssetCacheState.asset_id)
            .join(Asset, Asset.id == AssetCacheState.asset_id)
            .where(Asset.hash.is_(None))
            .where(sa.or_(*conds))
            .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
        )
    ).all()
    # keep only the first cache state per asset, without a per-row fetch
    seen: set[str] = set()
    result: list[int] = []
    for sid, aid in rows:
        if aid not in seen:
            seen.add(aid)
            result.append(sid)
    return result
async def list_verify_candidates_under_prefixes(
session: AsyncSession, *, prefixes: Sequence[str]
) -> Union[list[int], Sequence[int]]:
if not prefixes:
return []
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
conds.append(AssetCacheState.file_path.like(base + "%"))
return (
await session.execute(
sa.select(AssetCacheState.id)
.where(AssetCacheState.needs_verify.is_(True))
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
async def ingest_fs_asset(
session: AsyncSession,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: Optional[str] = None,
info_name: Optional[str] = None,
owner_id: str = "",
preview_id: Optional[str] = None,
user_metadata: Optional[dict] = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not await session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
async with session.begin_nested():
asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now)
session.add(asset)
await session.flush()
out["asset_created"] = True
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
dialect = session.bind.dialect.name
if dialect == "sqlite":
ins = (
d_sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
elif dialect == "postgresql":
ins = (
d_pg.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
else:
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
res = await session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = await session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
# upsert by (asset_id, owner_id, name)
try:
async with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
await session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
await session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
await session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
await ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
await session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
await session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = (
await session.execute(
select(AssetCacheState.file_path)
.where(AssetCacheState.asset_id == asset.id)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
async def touch_asset_infos_by_fs_path(
session: AsyncSession,
*,
file_path: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> int:
locator = os.path.abspath(file_path)
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(AssetCacheState)
.where(
AssetCacheState.asset_id == AssetInfo.asset_id,
AssetCacheState.file_path == locator,
)
)
)
if only_if_newer:
stmt = stmt.where(
sa.or_(
AssetInfo.last_access_time.is_(None),
AssetInfo.last_access_time < ts,
)
)
stmt = stmt.values(last_access_time=ts)
res = await session.execute(stmt)
return int(res.rowcount or 0)
async def list_cache_states_with_asset_under_prefixes(
session: AsyncSession,
*,
prefixes: Sequence[str],
) -> list[tuple[AssetCacheState, Optional[str], int]]:
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
if not prefixes:
return []
conds = []
for p in prefixes:
if not p:
continue
base = os.path.abspath(p)
if not base.endswith(os.sep):
base = base + os.sep
conds.append(AssetCacheState.file_path.like(base + "%"))
if not conds:
return []
rows = (
await session.execute(
select(AssetCacheState, Asset.hash, Asset.size_bytes)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sa.or_(*conds))
.order_by(AssetCacheState.id.asc())
)
).all()
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
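
The functions above are designed to run in two phases: a cheap seed pass that registers files with hash=NULL, and a background pass that hashes and dedupes. A condensed sketch of that loop, assuming an async_sessionmaker and a models directory (only the two session blocks are hypothetical glue; the helpers are from this module):

    async def seed_then_hash(sessionmaker, models_dir: str):
        # phase 1: register every file as a seed Asset (hash=NULL) with a cache state
        async with sessionmaker() as session, session.begin():
            for name in os.listdir(models_dir):
                path = os.path.join(models_dir, name)
                if not os.path.isfile(path):
                    continue
                st = os.stat(path)
                await ensure_seed_for_path(
                    session, abs_path=path, size_bytes=st.st_size, mtime_ns=st.st_mtime_ns,
                    info_name=name, tags=["models"],
                )
        # phase 2: compute hashes, claim or merge into canonical assets
        async with sessionmaker() as session, session.begin():
            for state_id in await list_unhashed_candidates_under_prefixes(session, prefixes=[models_dir]):
                await compute_hash_and_dedup_for_cache_state(session, state_id=state_id)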

View File

@ -0,0 +1,579 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, noload
from ..._assets_helpers import compute_model_relative_filename, normalize_tags
from ..helpers import (
apply_metadata_filter,
apply_tag_filters,
ensure_tags_exist,
project_kv,
visible_owner_clause,
)
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from ..timeutil import utcnow
from .queries import get_asset_by_hash, get_cache_state_by_asset_id
async def list_asset_infos_page(
session: AsyncSession,
*,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
metadata_filter: Optional[dict] = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((await session.execute(count_stmt)).scalar_one() or 0)
infos = (await session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = await session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
async def fetch_asset_info_and_asset(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = await session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
async def fetch_asset_info_asset_and_tags(
session: AsyncSession,
*,
asset_info_id: str,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (await session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
asset_hash: str,
name: str,
user_metadata: Optional[dict] = None,
tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
async with session.begin_nested():
session.add(info)
await session.flush()
except IntegrityError:
existing = (
await session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
state = await get_cache_state_by_asset_id(session, asset_id=asset.id)
if state and state.file_path:
computed_filename = compute_model_relative_filename(state.file_path)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
await replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
async def set_asset_info_tags(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
await ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
await session.flush()
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
await session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
async def update_asset_info_full(
session: AsyncSession,
*,
asset_info_id: str,
name: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id)
if state and state.file_path:
computed_filename = compute_model_relative_filename(state.file_path)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
await replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
await set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
await session.flush()
return info
async def replace_asset_info_metadata_projection(
session: AsyncSession,
*,
asset_info_id: str,
user_metadata: Optional[dict],
) -> None:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
await session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
await session.flush()
async def touch_asset_info_by_id(
session: AsyncSession,
*,
asset_info_id: str,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> int:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
stmt = stmt.values(last_access_time=ts)
res = await session.execute(stmt)
return int(res.rowcount or 0)
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
res = await session.execute(delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
))
return bool(res.rowcount)
async def add_tags_to_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
await ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
async with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
await session.flush()
except IntegrityError:
await nested.rollback()
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
async def remove_tags_from_asset_info(
session: AsyncSession,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
await session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
await session.flush()
total = await get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
async def list_tags_with_usage(
session: AsyncSession,
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (await session.execute(q.limit(limit).offset(offset))).all()
total = (await session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
await session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
async def set_asset_info_preview(
session: AsyncSession,
*,
asset_info_id: str,
preview_asset_id: Optional[str],
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not await session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
await session.flush()
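
A typical read path over this module, e.g. backing a paged listing endpoint (session handling and the response shape are illustrative):

    async def list_page(session, owner_id: str = "", page: int = 0, size: int = 20):
        infos, tag_map, total = await list_asset_infos_page(
            session, owner_id=owner_id, include_tags=["models"],
            limit=size, offset=page * size, sort="created_at", order="desc",
        )
        items = [
            {"id": i.id, "name": i.name, "size": i.asset.size_bytes, "tags": tag_map.get(i.id, [])}
            for i in infos
        ]
        return {"items": items, "total": total}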

View File

@ -0,0 +1,59 @@
from typing import Optional, Sequence, Union
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from ..models import Asset, AssetCacheState, AssetInfo
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
row = (
await session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
return (
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (await session.execute(q)).first() is not None
async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
async def list_cache_states_by_asset_id(
session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()

View File

@ -212,7 +212,6 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
if comfy.options.args_parsing:

View File

@ -279,11 +279,11 @@ def cleanup_temp():
shutil.rmtree(temp_dir, ignore_errors=True)
async def setup_database():
from app import init_db_engine, start_background_assets_scan
from app import init_db_engine, sync_seed_assets
await init_db_engine()
if not args.disable_assets_autoscan:
await start_background_assets_scan()
await sync_seed_assets(["models", "input", "output"])
def start_comfyui(asyncio_loop=None):

View File

@ -37,7 +37,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from app.api.assets_routes import register_assets_system
from app import sync_seed_assets, register_assets_system
from protocol import BinaryEventTypes
async def send_socket_catch_exception(function, message):
@ -629,6 +629,7 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
await sync_seed_assets(["models"])
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:

View File

@ -118,6 +118,16 @@ async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, se
assert rh2.status == 404
@pytest.mark.asyncio
async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str):
# Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
# aiohttp exposes an empty body for HEAD, so validate status and that there is no payload.
async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh:
assert rh.status == 400
body = await rh.read()
assert body == b""
@pytest.mark.asyncio
async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str):
bogus = str(uuid.uuid4())
@ -166,12 +176,3 @@ async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, a
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.asyncio
async def test_head_asset_bad_hash(http: aiohttp.ClientSession, api_base: str):
# Invalid format
async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3:
jb = await rh3.json()
assert rh3.status == 400
assert jb is None # HEAD request should not include "body" in response

View File

@ -66,23 +66,32 @@ async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, s
async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
b1 = await r1.json()
assert r1.status == 200, b1
# normalized and deduplicated
assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"]
# normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"}
assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
g = await rg.json()
assert rg.status == 200
tags_now = set(g["tags"])
assert "newtag" in tags_now
assert "beta" in tags_now
assert {"newtag", "beta"}.issubset(tags_now)
# Remove a tag and a non-existent tag
payload_del = {"tags": ["newtag", "does-not-exist"]}
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
b2 = await r2.json()
assert r2.status == 200
assert "newtag" in b2["removed"]
assert "does-not-exist" in b2["not_present"]
assert set(b2["removed"]) == {"newtag"}
assert set(b2["not_present"]) == {"does-not-exist"}
# Verify remaining tags after deletion
async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
g2 = await rg2.json()
assert rg2.status == 200
tags_later = set(g2["tags"])
assert "newtag" not in tags_later
assert "beta" in tags_later # still present
@pytest.mark.asyncio

View File

@ -206,7 +206,7 @@ async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_b
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
assert body["error"]["message"].startswith("unknown models category")
@pytest.mark.asyncio