diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py
index b180edacc..9fb80ea8c 100644
--- a/alembic_db/versions/0001_assets.py
+++ b/alembic_db/versions/0001_assets.py
@@ -1,3 +1,4 @@
+# File: /alembic_db/versions/0001_assets.py
 """initial assets schema + per-asset state cache

 Revision ID: 0001_assets
@@ -22,15 +23,12 @@ def upgrade() -> None:
         sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
         sa.Column("mime_type", sa.String(length=255), nullable=True),
         sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
-        sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
-        sa.Column("storage_locator", sa.Text(), nullable=False),
         sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
         sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
         sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
         sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
     )
     op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
-    op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])

     # ASSETS_INFO: user-visible references (mutable metadata)
     op.create_table(
@@ -52,11 +50,12 @@ def upgrade() -> None:
     op.create_index("ix_assets_info_name", "assets_info", ["name"])
     op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
     op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
+    op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])

     # TAGS: normalized tag vocabulary
     op.create_table(
         "tags",
-        sa.Column("name", sa.String(length=128), primary_key=True),
+        sa.Column("name", sa.String(length=512), primary_key=True),
         sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
         sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
     )
@@ -65,8 +64,8 @@ def upgrade() -> None:
     # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
     op.create_table(
         "asset_info_tags",
-        sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
-        sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
+        sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
+        sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
         sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
         sa.Column("added_by", sa.String(length=128), nullable=True),
         sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
@@ -75,15 +74,15 @@ def upgrade() -> None:
     op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
     op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

-    # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records
+    # ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset
     op.create_table(
-        "asset_locator_state",
+        "asset_cache_state",
         sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
+        sa.Column("file_path", sa.Text(), nullable=False),  # absolute local path to cached file
         sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
-        sa.Column("etag", sa.String(length=256), nullable=True),
-        sa.Column("last_modified", sa.String(length=128), nullable=True),
-        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
+        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
     )
+    op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])

     # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
     op.create_table(
@@ -102,6 +101,21 @@ def upgrade() -> None:
     op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
     op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

+    # ASSET_LOCATIONS: remote locations per asset
+    op.create_table(
+        "asset_locations",
+        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
+        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
+        sa.Column("provider", sa.String(length=32), nullable=False),  # e.g., "gcs"
+        sa.Column("locator", sa.Text(), nullable=False),  # e.g., "gs://bucket/path/to/blob"
+        sa.Column("expected_size_bytes", sa.BigInteger(), nullable=True),
+        sa.Column("etag", sa.String(length=256), nullable=True),
+        sa.Column("last_modified", sa.String(length=128), nullable=True),
+        sa.UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
+    )
+    op.create_index("ix_asset_locations_hash", "asset_locations", ["asset_hash"])
+    op.create_index("ix_asset_locations_provider", "asset_locations", ["provider"])
+
     # Tags vocabulary for models
     tags_table = sa.table(
         "tags",
@@ -143,13 +157,18 @@ def upgrade() -> None:


 def downgrade() -> None:
+    op.drop_index("ix_asset_locations_provider", table_name="asset_locations")
+    op.drop_index("ix_asset_locations_hash", table_name="asset_locations")
+    op.drop_table("asset_locations")
+
     op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
     op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
     op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
     op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
     op.drop_table("asset_info_meta")

-    op.drop_table("asset_locator_state")
+    op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
+    op.drop_table("asset_cache_state")

     op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
     op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
@@ -159,6 +178,7 @@ def downgrade() -> None:
     op.drop_table("tags")

     op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
+    op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
     op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
     op.drop_index("ix_assets_info_created_at", table_name="assets_info")
     op.drop_index("ix_assets_info_name", table_name="assets_info")
@@ -166,6 +186,5 @@ def downgrade() -> None:
     op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
     op.drop_table("assets_info")

-    op.drop_index("ix_assets_backend_locator", table_name="assets")
     op.drop_index("ix_assets_mime_type", table_name="assets")
     op.drop_table("assets")
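The migration replaces the single `storage_backend`/`storage_locator` pair on `assets` with two side tables: `asset_cache_state` holds at most one verified local copy per asset, while `asset_locations` can hold any number of remote copies. A minimal sketch of reading the two together, using the ORM models introduced later in this diff (the `app.database.models` import path and the `describe_asset` helper are illustrative, not part of the change):

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession

    from app.database.models import Asset, AssetCacheState, AssetLocation


    async def describe_asset(session: AsyncSession, asset_hash: str) -> dict:
        asset = await session.get(Asset, asset_hash)            # identity row (hash is the PK)
        cache = await session.get(AssetCacheState, asset_hash)  # local copy, if any
        locations = (
            await session.execute(
                select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
            )
        ).scalars().all()                                        # remote copies, possibly none
        return {
            "known": asset is not None,
            "cached_path": cache.file_path if cache else None,
            "remote": [(loc.provider, loc.locator) for loc in locations],
        }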
asset") - bases = folder_paths.folder_names_and_paths[tags[1]][0] + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") if not bases: raise ValueError(f"no base path configured for category '{tags[1]}'") base_dir = os.path.abspath(bases[0]) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 7fbc69467..c0dde7909 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -9,7 +9,7 @@ from pydantic import ValidationError import folder_paths from .. import assets_manager, assets_scanner -from . import schemas_in +from . import schemas_in, schemas_out ROUTES = web.RouteTableDef() @@ -20,6 +20,9 @@ async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() if not hash_str or ":" not in hash_str: return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = hash_str.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") exists = await assets_manager.asset_exists(asset_hash=hash_str) return web.Response(status=200 if exists else 404) @@ -69,7 +72,7 @@ async def download_asset_content(request: web.Request) -> web.Response: except FileNotFoundError: return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") - quoted = filename.replace('"', "'") + quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' resp = web.FileResponse(abs_path) @@ -115,6 +118,7 @@ async def upload_asset(request: web.Request) -> web.Response: user_metadata_raw: Optional[str] = None file_written = 0 + tmp_path: Optional[str] = None while True: field = await reader.next() if field is None: @@ -173,6 +177,8 @@ async def upload_asset(request: web.Request) -> web.Response: return _validation_error_response("INVALID_BODY", ve) if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'") try: @@ -182,12 +188,14 @@ async def upload_asset(request: web.Request) -> web.Response: client_filename=file_client_name, ) return web.json_response(created.model_dump(mode="json"), status=201) + except ValueError: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response(400, "BAD_REQUEST", "Invalid inputs.") except Exception: - try: - if os.path.exists(tmp_path): - os.remove(tmp_path) - finally: - return _error_response(500, "INTERNAL", "Unexpected server error.") + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response(500, "INTERNAL", "Unexpected server error.") @ROUTES.put("/api/assets/{id}") @@ -341,6 +349,7 @@ async def get_asset_scan_status(request: web.Request) -> web.Response: states = assets_scanner.current_statuses() if root in {"models", "input", "output"}: states = [s for s in states.scans if s.root == root] # type: ignore + states = schemas_out.AssetScanStatusResponse(scans=states) return web.json_response(states.model_dump(mode="json"), status=200) diff --git a/app/assets_fetcher.py b/app/assets_fetcher.py new file mode 100644 index 000000000..ea1c8ed00 --- /dev/null +++ b/app/assets_fetcher.py @@ -0,0 
diff --git a/app/assets_fetcher.py b/app/assets_fetcher.py
new file mode 100644
index 000000000..ea1c8ed00
--- /dev/null
+++ b/app/assets_fetcher.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+import asyncio
+import os
+import tempfile
+from typing import Optional
+import mimetypes
+
+import aiohttp
+
+from .storage.hashing import blake3_hash_sync
+from .database.db import create_session
+from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash
+from .resolvers import resolve_asset
+from ._assets_helpers import resolve_destination_from_tags, ensure_within_base
+
+_FETCH_LOCKS: dict[str, asyncio.Lock] = {}
+
+
+def _sanitize_filename(name: str) -> str:
+    return os.path.basename((name or "").strip()) or "file"
+
+
+async def ensure_asset_cached(
+    asset_hash: str,
+    *,
+    preferred_name: Optional[str] = None,
+    tags_hint: Optional[list[str]] = None,
+) -> str:
+    """
+    Ensure there is a verified local file for `asset_hash` in the correct Comfy folder.
+
+    Policy:
+      - Resolver must provide valid tags (root and, for models, category).
+      - If target path already exists:
+          * if hash matches -> reuse & ingest
+          * else -> remove and overwrite with the correct content
+    """
+    lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
+    async with lock:
+        # 1) If we already have a state -> trust the path
+        async with await create_session() as sess:
+            state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash)
+            if state and os.path.isfile(state.file_path):
+                return state.file_path
+
+        # 2) Resolve remote location + placement hints (must include valid tags)
+        res = await resolve_asset(asset_hash)
+        if not res:
+            raise FileNotFoundError(f"No resolver/locations for {asset_hash}")
+
+        placement_tags = tags_hint or res.tags
+        if not placement_tags:
+            raise ValueError(f"Resolver did not provide placement tags for {asset_hash}")
+
+        name_hint = res.filename or preferred_name or asset_hash.replace(":", "_")
+        safe_name = _sanitize_filename(name_hint)
+
+        # 3) Map tags -> destination (strict: raises if invalid root or models category)
+        base_dir, subdirs = resolve_destination_from_tags(placement_tags)  # may raise
+        dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
+        os.makedirs(dest_dir, exist_ok=True)
+
+        final_path = os.path.abspath(os.path.join(dest_dir, safe_name))
+        ensure_within_base(final_path, base_dir)
+
+        # 4) If target path exists, try to reuse; else delete invalid cache
+        if os.path.exists(final_path) and os.path.isfile(final_path):
+            existing_digest = blake3_hash_sync(final_path)
+            if f"blake3:{existing_digest}" == asset_hash:
+                size_bytes = os.path.getsize(final_path)
+                mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
+                async with await create_session() as sess:
+                    await ingest_fs_asset(
+                        sess,
+                        asset_hash=asset_hash,
+                        abs_path=final_path,
+                        size_bytes=size_bytes,
+                        mtime_ns=mtime_ns,
+                        mime_type=None,
+                        info_name=None,
+                        tags=(),
+                    )
+                    await sess.commit()
+                return final_path
+            else:
+                # Invalid cache: remove before re-downloading
+                os.remove(final_path)
+
+        # 5) Download to temp next to destination
+        timeout = aiohttp.ClientTimeout(total=60 * 30)
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            async with session.get(res.download_url, headers=dict(res.headers)) as resp:
+                resp.raise_for_status()
+                cl = resp.headers.get("Content-Length")
+                if res.expected_size and cl and int(cl) != int(res.expected_size):
+                    raise ValueError("server Content-Length does not match expected size")
+                with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp:
+                    tmp_path = tmp.name
+                    async for chunk in resp.content.iter_chunked(8 * 1024 * 1024):
+                        if chunk:
+                            tmp.write(chunk)
+
+        # 6) Verify content hash
+        digest = blake3_hash_sync(tmp_path)
+        canonical = f"blake3:{digest}"
+        if canonical != asset_hash:
+            try:
+                os.remove(tmp_path)
+            finally:
+                raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")
+
+        # 7) Atomically move into place (we already removed an invalid file if it existed)
+        if os.path.exists(final_path):
+            os.remove(final_path)
+        os.replace(tmp_path, final_path)
+
+        # 8) Record identity + cache state (+ mime type)
+        size_bytes = os.path.getsize(final_path)
+        mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
+        mime_type = mimetypes.guess_type(safe_name, strict=False)[0]
+        async with await create_session() as sess:
+            await ingest_fs_asset(
+                sess,
+                asset_hash=asset_hash,
+                abs_path=final_path,
+                size_bytes=size_bytes,
+                mtime_ns=mtime_ns,
+                mime_type=mime_type,
+                info_name=None,
+                tags=(),
+            )
+            await sess.commit()
+
+        return final_path
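Callers are expected to go through ensure_asset_cached rather than touching paths directly; it returns the verified local path, downloading and ingesting on a cache miss. A minimal usage sketch (the digest, filename, and tag hint below are placeholders):

    from app.assets_fetcher import ensure_asset_cached

    async def load_example_vae() -> str:
        path = await ensure_asset_cached(
            "blake3:0123abcd...",                      # placeholder digest
            preferred_name="example_vae.safetensors",  # placeholder filename
            tags_hint=["models", "vae"],               # maps to the configured vae folder
        )
        return path  # verified file inside the Comfy folder tree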
diff --git a/app/assets_manager.py b/app/assets_manager.py
index f6c839b8b..72d299467 100644
--- a/app/assets_manager.py
+++ b/app/assets_manager.py
@@ -26,7 +26,8 @@ from .database.services import (
     create_asset_info_for_existing_asset,
 )
 from .api import schemas_in, schemas_out
-from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base
+from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
+from .assets_fetcher import ensure_asset_cached


 async def asset_exists(*, asset_hash: str) -> bool:
@@ -46,17 +47,17 @@ def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) ->
             file_name=asset_name,
             file_path=file_path,
         )
-    except ValueError:
-        logging.exception("Cant parse '%s' as an asset file path.", file_path)
+    except ValueError as e:
+        logging.warning("Skipping non-asset path %s: %s", file_path, e)


 async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
     """Adds a local asset to the DB. If already present and unchanged, does nothing.

     Notes:
-        - Uses absolute path as the canonical locator for the 'fs' backend.
+        - Uses absolute path as the canonical locator for the cache backend.
         - Computes BLAKE3 only when the fast existence check indicates it's needed.
-        - This function ensures the identity row and seeds mtime in asset_locator_state.
+        - This function ensures the identity row and seeds mtime in asset_cache_state.
     """
     abs_path = os.path.abspath(file_path)
     size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
@@ -125,7 +126,7 @@ async def list_assets(
             size=int(asset.size_bytes) if asset else None,
             mime_type=asset.mime_type if asset else None,
             tags=tags,
-            preview_url=f"/api/v1/assets/{info.id}/content",
+            preview_url=f"/api/assets/{info.id}/content",
             created_at=info.created_at,
             updated_at=info.updated_at,
             last_access_time=info.last_access_time,
@@ -143,12 +144,11 @@ async def resolve_asset_content_for_download(
     *, asset_info_id: int
 ) -> tuple[str, str, str]:
     """
-    Returns (abs_path, content_type, download_name) for the given AssetInfo id.
+    Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
     Also touches last_access_time (only_if_newer).
+    Ensures the local cache is present (uses resolver if needed).
     Raises:
-        ValueError if AssetInfo not found
-        NotImplementedError for unsupported backend
-        FileNotFoundError if underlying file does not exist (fs backend)
+        ValueError if AssetInfo cannot be found
     """
     async with await create_session() as session:
         pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
@@ -156,21 +156,19 @@ async def resolve_asset_content_for_download(
             raise ValueError(f"AssetInfo {asset_info_id} not found")
         info, asset = pair
+        tag_names = await get_asset_tags(session, asset_info_id=info.id)

-    if asset.storage_backend != "fs":
-        # Future: support http/s3/gcs/...
-        raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet")
-
-    abs_path = os.path.abspath(asset.storage_locator)
-    if not os.path.exists(abs_path):
-        raise FileNotFoundError(abs_path)
+    # Ensure cached (download if missing)
+    preferred_name = info.name or info.asset_hash.split(":", 1)[-1]
+    abs_path = await ensure_asset_cached(info.asset_hash, preferred_name=preferred_name, tags_hint=tag_names)

+    async with await create_session() as session:
         await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
         await session.commit()

-        ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
-        download_name = info.name or os.path.basename(abs_path)
-        return abs_path, ctype, download_name
+    ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
+    download_name = info.name or os.path.basename(abs_path)
+    return abs_path, ctype, download_name


 async def upload_asset_from_temp_path(
@@ -238,7 +236,7 @@ async def upload_asset_from_temp_path(
             added_by=None,
             require_existing_tags=False,
         )
-        info_id = result.get("asset_info_id")
+        info_id = result["asset_info_id"]
         if not info_id:
             raise RuntimeError("failed to create asset metadata")

@@ -260,7 +258,7 @@ async def upload_asset_from_temp_path(
         preview_hash=info.preview_hash,
         created_at=info.created_at,
         last_access_time=info.last_access_time,
-        created_new=True,
+        created_new=result["asset_created"],
     )


@@ -416,7 +414,7 @@ def _get_size_mtime_ns(path: str) -> tuple[int, int]:
     return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))


-def _safe_filename(name: Optional[str] , fallback: str) -> str:
+def _safe_filename(name: Optional[str], fallback: str) -> str:
     n = os.path.basename((name or "").strip() or fallback)
     if n:
         return n
diff --git a/app/assets_scanner.py b/app/assets_scanner.py
index 691472156..5bafd6bb7 100644
--- a/app/assets_scanner.py
+++ b/app/assets_scanner.py
@@ -147,9 +147,7 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
         prog.last_error = str(exc)
     finally:
         prog.finished_at = time.time()
-        t = RUNNING_TASKS.get(root)
-        if t and t.done():
-            RUNNING_TASKS.pop(root, None)
+        RUNNING_TASKS.pop(root, None)


 async def _scan_models(prog: ScanProgress) -> None:
diff --git a/app/database/models.py b/app/database/models.py
index 20b88ca68..d964a5226 100644
--- a/app/database/models.py
+++ b/app/database/models.py
@@ -45,8 +45,6 @@ class Asset(Base):
     size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
     mime_type: Mapped[str | None] = mapped_column(String(255))
     refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
-    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
-    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
     created_at: Mapped[datetime] = mapped_column(
         DateTime(timezone=False), nullable=False, default=utcnow
     )
@@ -71,48 +69,71 @@ class Asset(Base):
         viewonly=True,
     )

-    locator_state: Mapped["AssetLocatorState | None"] = relationship(
+    cache_state: Mapped["AssetCacheState | None"] = relationship(
         back_populates="asset",
         uselist=False,
         cascade="all, delete-orphan",
         passive_deletes=True,
     )

+    locations: Mapped[list["AssetLocation"]] = relationship(
+        back_populates="asset",
+        cascade="all, delete-orphan",
+        passive_deletes=True,
+    )
+
     __table_args__ = (
         Index("ix_assets_mime_type", "mime_type"),
-        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
     )

     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
         return to_dict(self, include_none=include_none)

     def __repr__(self) -> str:
-        return f""
+        return f""


-class AssetLocatorState(Base):
-    __tablename__ = "asset_locator_state"
+class AssetCacheState(Base):
+    __tablename__ = "asset_cache_state"

     asset_hash: Mapped[str] = mapped_column(
         String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
     )
-    # For fs backends: nanosecond mtime; nullable if not applicable
+    file_path: Mapped[str] = mapped_column(Text, nullable=False)
     mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
-    # For HTTP/S3/GCS/Azure, etc.: optional validators
-    etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
-    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)

-    asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
+    asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)

     __table_args__ = (
-        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
+        Index("ix_asset_cache_state_file_path", "file_path"),
+        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
     )

     def to_dict(self, include_none: bool = False) -> dict[str, Any]:
         return to_dict(self, include_none=include_none)

     def __repr__(self) -> str:
-        return f""
+        return f""
+
+
+class AssetLocation(Base):
+    __tablename__ = "asset_locations"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
+    provider: Mapped[str] = mapped_column(String(32), nullable=False)  # "gcs"
+    locator: Mapped[str] = mapped_column(Text, nullable=False)  # "gs://bucket/object"
+    expected_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
+    etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
+    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
+
+    asset: Mapped["Asset"] = relationship(back_populates="locations")
+
+    __table_args__ = (
+        UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
+        Index("ix_asset_locations_hash", "asset_hash"),
+        Index("ix_asset_locations_provider", "provider"),
+    )


 class AssetInfo(Base):
@@ -220,7 +241,7 @@ class AssetInfoTag(Base):
         Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
     )
     tag_name: Mapped[str] = mapped_column(
-        String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
+        String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
     )
     origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
     added_by: Mapped[str | None] = mapped_column(String(128))
@@ -240,7 +261,7 @@ class AssetInfoTag(Base):
 class Tag(Base):
     __tablename__ = "tags"

-    name: Mapped[str] = mapped_column(String(128), primary_key=True)
+    name: Mapped[str] = mapped_column(String(512), primary_key=True)
     tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

     asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
diff --git a/app/database/services.py b/app/database/services.py
index 34029b139..5f1ffffbf 100644
--- a/app/database/services.py
+++ b/app/database/services.py
@@ -12,7 +12,7 @@ from sqlalchemy import select, delete, exists, func
 from sqlalchemy.orm import contains_eager
 from sqlalchemy.exc import IntegrityError

-from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
+from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
 from .timeutil import utcnow
 from .._assets_helpers import normalize_tags

@@ -38,30 +38,24 @@ async def check_fs_asset_exists_quick(
     mtime_ns: Optional[int] = None,
 ) -> bool:
     """
-    Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
+    Returns 'True' if there is already an AssetCacheState record that matches this absolute path,
     AND (if provided) mtime_ns matches stored locator-state,
     AND (if provided) size_bytes matches verified size when known.
     """
     locator = os.path.abspath(file_path)

-    stmt = select(sa.literal(True)).select_from(Asset)
+    stmt = select(sa.literal(True)).select_from(AssetCacheState).join(
+        Asset, Asset.hash == AssetCacheState.asset_hash
+    ).where(AssetCacheState.file_path == locator).limit(1)

-    conditions = [
-        Asset.storage_backend == "fs",
-        Asset.storage_locator == locator,
-    ]
-
-    # If size_bytes provided require equality when the asset has a verified (non-zero) size.
-    # If verified size is 0 (unknown), we don't force equality.
-    if size_bytes is not None:
-        conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
-
-    # If mtime_ns provided require the locator-state to exist and match.
+    conds = []
     if mtime_ns is not None:
-        stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
-        conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
+        conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
+    if size_bytes is not None:
+        conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))

-    stmt = stmt.where(*conditions).limit(1)
+    if conds:
+        stmt = stmt.where(*conds)

     row = (await session.execute(stmt)).first()
     return row is not None

@@ -85,11 +79,11 @@ async def ingest_fs_asset(
     require_existing_tags: bool = False,
 ) -> dict:
     """
-    Creates or updates Asset record for a local (fs) asset.
+    Upsert the Asset identity row and the cache state pointing at the local file.

     Always:
       - Insert Asset if missing; else update size_bytes (and updated_at) if different.
-      - Insert AssetLocatorState if missing; else update mtime_ns if different.
+      - Insert AssetCacheState if missing; else update mtime_ns if different.

     Optionally (when info_name is provided):
       - Create an AssetInfo (no refcount changes).
@@ -126,8 +120,6 @@ async def ingest_fs_asset(
             size_bytes=int(size_bytes),
             mime_type=mime_type,
             refcount=0,
-            storage_backend="fs",
-            storage_locator=locator,
             created_at=datetime_now,
             updated_at=datetime_now,
         )
@@ -145,21 +137,19 @@ async def ingest_fs_asset(
             if mime_type and existing.mime_type != mime_type:
                 existing.mime_type = mime_type
                 changed = True
-            if existing.storage_locator != locator:
-                existing.storage_locator = locator
-                changed = True
             if changed:
                 existing.updated_at = datetime_now
                 out["asset_updated"] = True
         else:
             logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash)

-    # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
+    # ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ----
     with contextlib.suppress(IntegrityError):
         async with session.begin_nested():
             session.add(
-                AssetLocatorState(
+                AssetCacheState(
                     asset_hash=asset_hash,
+                    file_path=locator,
                     mtime_ns=int(mtime_ns),
                 )
             )
@@ -167,11 +157,17 @@ async def ingest_fs_asset(
             out["state_created"] = True

     if not out["state_created"]:
-        state = await session.get(AssetLocatorState, asset_hash)
+        state = await session.get(AssetCacheState, asset_hash)
         if state is not None:
-            desired_mtime = int(mtime_ns)
-            if state.mtime_ns != desired_mtime:
-                state.mtime_ns = desired_mtime
+            changed = False
+            if state.file_path != locator:
+                state.file_path = locator
+                changed = True
+            if state.mtime_ns != int(mtime_ns):
+                state.mtime_ns = int(mtime_ns)
+                changed = True
+            if changed:
+                await session.flush()
                 out["state_updated"] = True
         else:
             logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash)
@@ -278,11 +274,10 @@ async def touch_asset_infos_by_fs_path(
     stmt = sa.update(AssetInfo).where(
         sa.exists(
             sa.select(sa.literal(1))
-            .select_from(Asset)
+            .select_from(AssetCacheState)
             .where(
-                Asset.hash == AssetInfo.asset_hash,
-                Asset.storage_backend == "fs",
-                Asset.storage_locator == locator,
+                AssetCacheState.asset_hash == AssetInfo.asset_hash,
+                AssetCacheState.file_path == locator,
             )
         )
     )
@@ -337,13 +332,6 @@ async def list_asset_infos_page(

     We purposely collect tags in a separate (single) query to avoid row explosion.
     """
-    # Clamp
-    if limit <= 0:
-        limit = 1
-    if limit > 100:
-        limit = 100
-    if offset < 0:
-        offset = 0

     # Build base query
     base = (
@@ -419,6 +407,66 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in
     return pair[0], pair[1]


+async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
+    return await session.get(AssetCacheState, asset_hash)
+
+
+async def list_asset_locations(
+    session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
+) -> list[AssetLocation]:
+    stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
+    if provider:
+        stmt = stmt.where(AssetLocation.provider == provider)
+    return (await session.execute(stmt)).scalars().all()
+
+
+async def upsert_asset_location(
+    session: AsyncSession,
+    *,
+    asset_hash: str,
+    provider: str,
+    locator: str,
+    expected_size_bytes: Optional[int] = None,
+    etag: Optional[str] = None,
+    last_modified: Optional[str] = None,
+) -> AssetLocation:
+    loc = (
+        await session.execute(
+            select(AssetLocation).where(
+                AssetLocation.asset_hash == asset_hash,
+                AssetLocation.provider == provider,
+                AssetLocation.locator == locator,
+            ).limit(1)
+        )
+    ).scalars().first()
+    if loc:
+        changed = False
+        if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes:
+            loc.expected_size_bytes = expected_size_bytes
+            changed = True
+        if etag is not None and loc.etag != etag:
+            loc.etag = etag
+            changed = True
+        if last_modified is not None and loc.last_modified != last_modified:
+            loc.last_modified = last_modified
+            changed = True
+        if changed:
+            await session.flush()
+        return loc
+
+    loc = AssetLocation(
+        asset_hash=asset_hash,
+        provider=provider,
+        locator=locator,
+        expected_size_bytes=expected_size_bytes,
+        etag=etag,
+        last_modified=last_modified,
+    )
+    session.add(loc)
+    await session.flush()
+    return loc
+
+
 async def create_asset_info_for_existing_asset(
     session: AsyncSession,
     *,
@@ -925,7 +973,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
         rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
     elif isinstance(value, (int, float, Decimal)):
         # store numeric; SQLAlchemy will coerce to Numeric
-        rows.append({"key": key, "ordinal": 0, "val_num": value})
+        num = value if isinstance(value, Decimal) else Decimal(str(value))
+        rows.append({"key": key, "ordinal": 0, "val_num": num})
     elif isinstance(value, str):
         rows.append({"key": key, "ordinal": 0, "val_str": value})
     else:
@@ -943,7 +992,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
             elif isinstance(x, bool):
                 rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
             elif isinstance(x, (int, float, Decimal)):
-                rows.append({"key": key, "ordinal": i, "val_num": x})
+                num = x if isinstance(x, Decimal) else Decimal(str(x))
+                rows.append({"key": key, "ordinal": i, "val_num": num})
             elif isinstance(x, str):
                 rows.append({"key": key, "ordinal": i, "val_str": x})
             else:
["models","vae","subdir"] + filename: Optional[str] = None # preferred basename + +class AssetResolver(Protocol): + provider: str + async def resolve(self, asset_hash: str) -> Optional[ResolveResult]: ... + + +_REGISTRY: list[AssetResolver] = [] + + +def register_resolver(resolver: AssetResolver) -> None: + """Append Resolver with simple de-dup per provider.""" + global _REGISTRY + _REGISTRY = [r for r in _REGISTRY if r.provider != resolver.provider] + [resolver] + + +async def resolve_asset(asset_hash: str) -> Optional[ResolveResult]: + for r in _REGISTRY: + with contextlib.suppress(Exception): # For Resolver failure we just try the next one + res = await r.resolve(asset_hash) + if res: + return res + return None