add AssetsResolver support

bigcat88 2025-08-26 14:19:56 +03:00
parent a763cbd39d
commit 6fade5da38
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
9 changed files with 371 additions and 106 deletions

View File

@@ -1,3 +1,4 @@
+# File: /alembic_db/versions/0001_assets.py
 """initial assets schema + per-asset state cache

 Revision ID: 0001_assets
@@ -22,15 +23,12 @@ def upgrade() -> None:
        sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
        sa.Column("mime_type", sa.String(length=255), nullable=True),
        sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
-        sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
-        sa.Column("storage_locator", sa.Text(), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
        sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
        sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
    )
    op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
-    op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])

    # ASSETS_INFO: user-visible references (mutable metadata)
    op.create_table(
@@ -52,11 +50,12 @@ def upgrade() -> None:
    op.create_index("ix_assets_info_name", "assets_info", ["name"])
    op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
    op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
+    op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])

    # TAGS: normalized tag vocabulary
    op.create_table(
        "tags",
-        sa.Column("name", sa.String(length=128), primary_key=True),
+        sa.Column("name", sa.String(length=512), primary_key=True),
        sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
        sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
    )
@@ -65,8 +64,8 @@ def upgrade() -> None:
    # ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
    op.create_table(
        "asset_info_tags",
-        sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
-        sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
+        sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
+        sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
        sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
        sa.Column("added_by", sa.String(length=128), nullable=True),
        sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
@@ -75,15 +74,15 @@ def upgrade() -> None:
    op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
    op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

-    # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records
+    # ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset
    op.create_table(
-        "asset_locator_state",
+        "asset_cache_state",
        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
+        sa.Column("file_path", sa.Text(), nullable=False),  # absolute local path to cached file
        sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
-        sa.Column("etag", sa.String(length=256), nullable=True),
-        sa.Column("last_modified", sa.String(length=128), nullable=True),
-        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
+        sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
    )
+    op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])

    # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
    op.create_table(
@@ -102,6 +101,21 @@ def upgrade() -> None:
    op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
    op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

+    # ASSET_LOCATIONS: remote locations per asset
+    op.create_table(
+        "asset_locations",
+        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
+        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
+        sa.Column("provider", sa.String(length=32), nullable=False),  # e.g., "gcs"
+        sa.Column("locator", sa.Text(), nullable=False),  # e.g., "gs://bucket/path/to/blob"
+        sa.Column("expected_size_bytes", sa.BigInteger(), nullable=True),
+        sa.Column("etag", sa.String(length=256), nullable=True),
+        sa.Column("last_modified", sa.String(length=128), nullable=True),
+        sa.UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
+    )
+    op.create_index("ix_asset_locations_hash", "asset_locations", ["asset_hash"])
+    op.create_index("ix_asset_locations_provider", "asset_locations", ["provider"])

    # Tags vocabulary for models
    tags_table = sa.table(
        "tags",
@@ -143,13 +157,18 @@ def upgrade() -> None:
def downgrade() -> None:
+    op.drop_index("ix_asset_locations_provider", table_name="asset_locations")
+    op.drop_index("ix_asset_locations_hash", table_name="asset_locations")
+    op.drop_table("asset_locations")

    op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
    op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
    op.drop_table("asset_info_meta")

-    op.drop_table("asset_locator_state")
+    op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
+    op.drop_table("asset_cache_state")

    op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
    op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
@@ -159,6 +178,7 @@ def downgrade() -> None:
    op.drop_table("tags")

    op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
+    op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
    op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
    op.drop_index("ix_assets_info_created_at", table_name="assets_info")
    op.drop_index("ix_assets_info_name", table_name="assets_info")
@@ -166,6 +186,5 @@ def downgrade() -> None:
    op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
    op.drop_table("assets_info")

-    op.drop_index("ix_assets_backend_locator", table_name="assets")
    op.drop_index("ix_assets_mime_type", table_name="assets")
    op.drop_table("assets")

View File

@@ -105,7 +105,10 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
    if root == "models":
        if len(tags) < 2:
            raise ValueError("at least two tags required for model asset")
-        bases = folder_paths.folder_names_and_paths[tags[1]][0]
+        try:
+            bases = folder_paths.folder_names_and_paths[tags[1]][0]
+        except KeyError:
+            raise ValueError(f"unknown model category '{tags[1]}'")
        if not bases:
            raise ValueError(f"no base path configured for category '{tags[1]}'")
        base_dir = os.path.abspath(bases[0])

View File

@@ -9,7 +9,7 @@ from pydantic import ValidationError
import folder_paths

from .. import assets_manager, assets_scanner
-from . import schemas_in
+from . import schemas_in, schemas_out

ROUTES = web.RouteTableDef()
@@ -20,6 +20,9 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
    hash_str = request.match_info.get("hash", "").strip().lower()
    if not hash_str or ":" not in hash_str:
        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
+    algo, digest = hash_str.split(":", 1)
+    if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
+        return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
    exists = await assets_manager.asset_exists(asset_hash=hash_str)
    return web.Response(status=200 if exists else 404)
@@ -69,7 +72,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
    except FileNotFoundError:
        return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")

-    quoted = filename.replace('"', "'")
+    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'

    resp = web.FileResponse(abs_path)
@@ -115,6 +118,7 @@ async def upload_asset(request: web.Request) -> web.Response:
    user_metadata_raw: Optional[str] = None
    file_written = 0
+    tmp_path: Optional[str] = None

    while True:
        field = await reader.next()
        if field is None:
@@ -173,6 +177,8 @@ async def upload_asset(request: web.Request) -> web.Response:
        return _validation_error_response("INVALID_BODY", ve)

    if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths:
+        if tmp_path and os.path.exists(tmp_path):
+            os.remove(tmp_path)
        return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'")

    try:
@@ -182,12 +188,14 @@ async def upload_asset(request: web.Request) -> web.Response:
            client_filename=file_client_name,
        )
        return web.json_response(created.model_dump(mode="json"), status=201)
+    except ValueError:
+        if tmp_path and os.path.exists(tmp_path):
+            os.remove(tmp_path)
+        return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
    except Exception:
-        try:
-            if os.path.exists(tmp_path):
-                os.remove(tmp_path)
-        finally:
-            return _error_response(500, "INTERNAL", "Unexpected server error.")
+        if tmp_path and os.path.exists(tmp_path):
+            os.remove(tmp_path)
+        return _error_response(500, "INTERNAL", "Unexpected server error.")


@ROUTES.put("/api/assets/{id}")
@@ -341,6 +349,7 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
    states = assets_scanner.current_statuses()
    if root in {"models", "input", "output"}:
        states = [s for s in states.scans if s.root == root]  # type: ignore
+        states = schemas_out.AssetScanStatusResponse(scans=states)
    return web.json_response(states.model_dump(mode="json"), status=200)

app/assets_fetcher.py (new file, 132 lines added)
View File

@@ -0,0 +1,132 @@
from __future__ import annotations

import asyncio
import os
import tempfile
from typing import Optional
import mimetypes

import aiohttp

from .storage.hashing import blake3_hash_sync
from .database.db import create_session
from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash
from .resolvers import resolve_asset
from ._assets_helpers import resolve_destination_from_tags, ensure_within_base

_FETCH_LOCKS: dict[str, asyncio.Lock] = {}


def _sanitize_filename(name: str) -> str:
    return os.path.basename((name or "").strip()) or "file"


async def ensure_asset_cached(
    asset_hash: str,
    *,
    preferred_name: Optional[str] = None,
    tags_hint: Optional[list[str]] = None,
) -> str:
    """
    Ensure there is a verified local file for `asset_hash` in the correct Comfy folder.

    Policy:
      - Resolver must provide valid tags (root and, for models, category).
      - If target path already exists:
          * if hash matches -> reuse & ingest
          * else -> remove and overwrite with the correct content
    """
    lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
    async with lock:
        # 1) If we already have a state -> trust the path
        async with await create_session() as sess:
            state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash)
            if state and os.path.isfile(state.file_path):
                return state.file_path

        # 2) Resolve remote location + placement hints (must include valid tags)
        res = await resolve_asset(asset_hash)
        if not res:
            raise FileNotFoundError(f"No resolver/locations for {asset_hash}")
        placement_tags = tags_hint or res.tags
        if not placement_tags:
            raise ValueError(f"Resolver did not provide placement tags for {asset_hash}")
        name_hint = res.filename or preferred_name or asset_hash.replace(":", "_")
        safe_name = _sanitize_filename(name_hint)

        # 3) Map tags -> destination (strict: raises if invalid root or models category)
        base_dir, subdirs = resolve_destination_from_tags(placement_tags)  # may raise
        dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
        os.makedirs(dest_dir, exist_ok=True)
        final_path = os.path.abspath(os.path.join(dest_dir, safe_name))
        ensure_within_base(final_path, base_dir)

        # 4) If target path exists, try to reuse; else delete invalid cache
        if os.path.exists(final_path) and os.path.isfile(final_path):
            existing_digest = blake3_hash_sync(final_path)
            if f"blake3:{existing_digest}" == asset_hash:
                size_bytes = os.path.getsize(final_path)
                mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
                async with await create_session() as sess:
                    await ingest_fs_asset(
                        sess,
                        asset_hash=asset_hash,
                        abs_path=final_path,
                        size_bytes=size_bytes,
                        mtime_ns=mtime_ns,
                        mime_type=None,
                        info_name=None,
                        tags=(),
                    )
                    await sess.commit()
                return final_path
            else:
                # Invalid cache: remove before re-downloading
                os.remove(final_path)

        # 5) Download to temp next to destination
        timeout = aiohttp.ClientTimeout(total=60 * 30)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.get(res.download_url, headers=dict(res.headers)) as resp:
                resp.raise_for_status()
                cl = resp.headers.get("Content-Length")
                if res.expected_size and cl and int(cl) != int(res.expected_size):
                    raise ValueError("server Content-Length does not match expected size")
                with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp:
                    tmp_path = tmp.name
                    async for chunk in resp.content.iter_chunked(8 * 1024 * 1024):
                        if chunk:
                            tmp.write(chunk)

        # 6) Verify content hash
        digest = blake3_hash_sync(tmp_path)
        canonical = f"blake3:{digest}"
        if canonical != asset_hash:
            try:
                os.remove(tmp_path)
            finally:
                raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")

        # 7) Atomically move into place (we already removed an invalid file if it existed)
        if os.path.exists(final_path):
            os.remove(final_path)
        os.replace(tmp_path, final_path)

        # 8) Record identity + cache state (+ mime type)
        size_bytes = os.path.getsize(final_path)
        mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
        mime_type = mimetypes.guess_type(safe_name, strict=False)[0]
        async with await create_session() as sess:
            await ingest_fs_asset(
                sess,
                asset_hash=asset_hash,
                abs_path=final_path,
                size_bytes=size_bytes,
                mtime_ns=mtime_ns,
                mime_type=mime_type,
                info_name=None,
                tags=(),
            )
            await sess.commit()
        return final_path
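
For orientation, a minimal usage sketch of ensure_asset_cached (not part of this commit; the hash, filename, and tags below are placeholders, and the import path assumes the app package layout shown in this diff):

import asyncio
from app.assets_fetcher import ensure_asset_cached

async def main() -> None:
    # Resolves the asset via the registered resolvers, downloads it into the
    # folder implied by the tags, verifies the BLAKE3 digest, records the
    # cache state, and returns the absolute local path.
    path = await ensure_asset_cached(
        "blake3:0123abcd...",                  # canonical "blake3:<hex>" identity (placeholder)
        preferred_name="example.safetensors",  # fallback name if the resolver has none
        tags_hint=["models", "checkpoints"],   # placement hint: root + models category
    )
    print(path)

asyncio.run(main())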

View File

@@ -26,7 +26,8 @@ from .database.services import (
    create_asset_info_for_existing_asset,
)
from .api import schemas_in, schemas_out
-from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base
+from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
+from .assets_fetcher import ensure_asset_cached


async def asset_exists(*, asset_hash: str) -> bool:
@@ -46,17 +47,17 @@ def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) ->
            file_name=asset_name,
            file_path=file_path,
        )
-    except ValueError:
-        logging.exception("Cant parse '%s' as an asset file path.", file_path)
+    except ValueError as e:
+        logging.warning("Skipping non-asset path %s: %s", file_path, e)


async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
    """Adds a local asset to the DB. If already present and unchanged, does nothing.

    Notes:
-    - Uses absolute path as the canonical locator for the 'fs' backend.
+    - Uses absolute path as the canonical locator for the cache backend.
    - Computes BLAKE3 only when the fast existence check indicates it's needed.
-    - This function ensures the identity row and seeds mtime in asset_locator_state.
+    - This function ensures the identity row and seeds mtime in asset_cache_state.
    """
    abs_path = os.path.abspath(file_path)
    size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
@@ -125,7 +126,7 @@ async def list_assets(
            size=int(asset.size_bytes) if asset else None,
            mime_type=asset.mime_type if asset else None,
            tags=tags,
-            preview_url=f"/api/v1/assets/{info.id}/content",
+            preview_url=f"/api/assets/{info.id}/content",
            created_at=info.created_at,
            updated_at=info.updated_at,
            last_access_time=info.last_access_time,
@@ -143,12 +144,11 @@ async def resolve_asset_content_for_download(
    *, asset_info_id: int
) -> tuple[str, str, str]:
    """
-    Returns (abs_path, content_type, download_name) for the given AssetInfo id.
+    Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
    Also touches last_access_time (only_if_newer).
+    Ensures the local cache is present (uses resolver if needed).

    Raises:
-        ValueError if AssetInfo not found
-        NotImplementedError for unsupported backend
-        FileNotFoundError if underlying file does not exist (fs backend)
+        ValueError if AssetInfo cannot be found
    """
    async with await create_session() as session:
        pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
@@ -156,21 +156,19 @@ async def resolve_asset_content_for_download(
            raise ValueError(f"AssetInfo {asset_info_id} not found")
        info, asset = pair
+        tag_names = await get_asset_tags(session, asset_info_id=info.id)

-        if asset.storage_backend != "fs":
-            # Future: support http/s3/gcs/...
-            raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet")
-        abs_path = os.path.abspath(asset.storage_locator)
-        if not os.path.exists(abs_path):
-            raise FileNotFoundError(abs_path)
-    async with await create_session() as session:
+        # Ensure cached (download if missing)
+        preferred_name = info.name or info.asset_hash.split(":", 1)[-1]
+        abs_path = await ensure_asset_cached(info.asset_hash, preferred_name=preferred_name, tags_hint=tag_names)

        await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
        await session.commit()

    ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
    download_name = info.name or os.path.basename(abs_path)
    return abs_path, ctype, download_name


async def upload_asset_from_temp_path(
@@ -238,7 +236,7 @@ async def upload_asset_from_temp_path(
        added_by=None,
        require_existing_tags=False,
    )
-    info_id = result.get("asset_info_id")
+    info_id = result["asset_info_id"]
    if not info_id:
        raise RuntimeError("failed to create asset metadata")
@@ -260,7 +258,7 @@ async def upload_asset_from_temp_path(
        preview_hash=info.preview_hash,
        created_at=info.created_at,
        last_access_time=info.last_access_time,
-        created_new=True,
+        created_new=result["asset_created"],
    )
@@ -416,7 +414,7 @@ def _get_size_mtime_ns(path: str) -> tuple[int, int]:
    return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))


-def _safe_filename(name: Optional[str] , fallback: str) -> str:
+def _safe_filename(name: Optional[str], fallback: str) -> str:
    n = os.path.basename((name or "").strip() or fallback)
    if n:
        return n

View File

@@ -147,9 +147,7 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
        prog.last_error = str(exc)
    finally:
        prog.finished_at = time.time()
-        t = RUNNING_TASKS.get(root)
-        if t and t.done():
-            RUNNING_TASKS.pop(root, None)
+        RUNNING_TASKS.pop(root, None)


async def _scan_models(prog: ScanProgress) -> None:

View File

@@ -45,8 +45,6 @@ class Asset(Base):
    size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
    mime_type: Mapped[str | None] = mapped_column(String(255))
    refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
-    storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
-    storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=False), nullable=False, default=utcnow
    )
@@ -71,48 +69,71 @@ class Asset(Base):
        viewonly=True,
    )

-    locator_state: Mapped["AssetLocatorState | None"] = relationship(
+    cache_state: Mapped["AssetCacheState | None"] = relationship(
        back_populates="asset",
        uselist=False,
        cascade="all, delete-orphan",
        passive_deletes=True,
    )
+    locations: Mapped[list["AssetLocation"]] = relationship(
+        back_populates="asset",
+        cascade="all, delete-orphan",
+        passive_deletes=True,
+    )

    __table_args__ = (
        Index("ix_assets_mime_type", "mime_type"),
-        Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
-        return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
+        return f"<Asset hash={self.hash[:12]}>"


-class AssetLocatorState(Base):
-    __tablename__ = "asset_locator_state"
+class AssetCacheState(Base):
+    __tablename__ = "asset_cache_state"

    asset_hash: Mapped[str] = mapped_column(
        String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
    )
-    # For fs backends: nanosecond mtime; nullable if not applicable
+    file_path: Mapped[str] = mapped_column(Text, nullable=False)
    mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
-    # For HTTP/S3/GCS/Azure, etc.: optional validators
-    etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
-    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)

-    asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
+    asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)

    __table_args__ = (
-        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
+        Index("ix_asset_cache_state_file_path", "file_path"),
+        CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
    )

    def to_dict(self, include_none: bool = False) -> dict[str, Any]:
        return to_dict(self, include_none=include_none)

    def __repr__(self) -> str:
-        return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
+        return f"<AssetCacheState hash={self.asset_hash[:12]} path={self.file_path!r}>"


+class AssetLocation(Base):
+    __tablename__ = "asset_locations"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
+    provider: Mapped[str] = mapped_column(String(32), nullable=False)  # "gcs"
+    locator: Mapped[str] = mapped_column(Text, nullable=False)  # "gs://bucket/object"
+    expected_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
+    etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
+    last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
+
+    asset: Mapped["Asset"] = relationship(back_populates="locations")
+
+    __table_args__ = (
+        UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
+        Index("ix_asset_locations_hash", "asset_hash"),
+        Index("ix_asset_locations_provider", "provider"),
+    )


class AssetInfo(Base):
@@ -220,7 +241,7 @@ class AssetInfoTag(Base):
        Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
    )
    tag_name: Mapped[str] = mapped_column(
-        String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
+        String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
    )
    origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
    added_by: Mapped[str | None] = mapped_column(String(128))
@@ -240,7 +261,7 @@ class AssetInfoTag(Base):
class Tag(Base):
    __tablename__ = "tags"

-    name: Mapped[str] = mapped_column(String(128), primary_key=True)
+    name: Mapped[str] = mapped_column(String(512), primary_key=True)
    tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

    asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(
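
For orientation, a short sketch of reading back the new ORM relationships (not part of this commit; the create_session import path is assumed from this diff, and the hash is a placeholder):

from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.database.db import create_session
from app.database.models import Asset

async def show_asset(asset_hash: str) -> None:
    async with await create_session() as session:
        stmt = (
            select(Asset)
            .options(selectinload(Asset.cache_state), selectinload(Asset.locations))
            .where(Asset.hash == asset_hash)
        )
        asset = (await session.execute(stmt)).scalars().first()
        if asset is None:
            return
        # cache_state is the single local copy; locations are remote replicas.
        print(asset.cache_state.file_path if asset.cache_state else "not cached")
        print([(loc.provider, loc.locator) for loc in asset.locations])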

View File

@@ -12,7 +12,7 @@ from sqlalchemy import select, delete, exists, func
from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError

-from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
+from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
from .timeutil import utcnow
from .._assets_helpers import normalize_tags
@@ -38,30 +38,24 @@ async def check_fs_asset_exists_quick(
    mtime_ns: Optional[int] = None,
) -> bool:
    """
-    Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
+    Returns 'True' if there is already AssetCacheState record that matches this absolute path,
    AND (if provided) mtime_ns matches stored locator-state,
    AND (if provided) size_bytes matches verified size when known.
    """
    locator = os.path.abspath(file_path)

-    stmt = select(sa.literal(True)).select_from(Asset)
-
-    conditions = [
-        Asset.storage_backend == "fs",
-        Asset.storage_locator == locator,
-    ]
-    # If size_bytes provided require equality when the asset has a verified (non-zero) size.
-    # If verified size is 0 (unknown), we don't force equality.
-    if size_bytes is not None:
-        conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
-    # If mtime_ns provided require the locator-state to exist and match.
-    if mtime_ns is not None:
-        stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
-        conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
-
-    stmt = stmt.where(*conditions).limit(1)
+    stmt = select(sa.literal(True)).select_from(AssetCacheState).join(
+        Asset, Asset.hash == AssetCacheState.asset_hash
+    ).where(AssetCacheState.file_path == locator).limit(1)
+
+    conds = []
+    if mtime_ns is not None:
+        conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
+    if size_bytes is not None:
+        conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
+    if conds:
+        stmt = stmt.where(*conds)

    row = (await session.execute(stmt)).first()
    return row is not None
@@ -85,11 +79,11 @@ async def ingest_fs_asset(
    require_existing_tags: bool = False,
) -> dict:
    """
-    Creates or updates Asset record for a local (fs) asset.
+    Upsert Asset identity row + cache state pointing at local file.

    Always:
      - Insert Asset if missing; else update size_bytes (and updated_at) if different.
-      - Insert AssetLocatorState if missing; else update mtime_ns if different.
+      - Insert AssetCacheState if missing; else update mtime_ns if different.

    Optionally (when info_name is provided):
      - Create an AssetInfo (no refcount changes).
@@ -126,8 +120,6 @@ async def ingest_fs_asset(
                    size_bytes=int(size_bytes),
                    mime_type=mime_type,
                    refcount=0,
-                    storage_backend="fs",
-                    storage_locator=locator,
                    created_at=datetime_now,
                    updated_at=datetime_now,
                )
@@ -145,21 +137,19 @@ async def ingest_fs_asset(
            if mime_type and existing.mime_type != mime_type:
                existing.mime_type = mime_type
                changed = True
-            if existing.storage_locator != locator:
-                existing.storage_locator = locator
-                changed = True
            if changed:
                existing.updated_at = datetime_now
                out["asset_updated"] = True
        else:
            logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash)

-    # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
+    # ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ----
    with contextlib.suppress(IntegrityError):
        async with session.begin_nested():
            session.add(
-                AssetLocatorState(
+                AssetCacheState(
                    asset_hash=asset_hash,
+                    file_path=locator,
                    mtime_ns=int(mtime_ns),
                )
            )
@@ -167,11 +157,17 @@ async def ingest_fs_asset(
            out["state_created"] = True

    if not out["state_created"]:
-        state = await session.get(AssetLocatorState, asset_hash)
+        state = await session.get(AssetCacheState, asset_hash)
        if state is not None:
-            desired_mtime = int(mtime_ns)
-            if state.mtime_ns != desired_mtime:
-                state.mtime_ns = desired_mtime
+            changed = False
+            if state.file_path != locator:
+                state.file_path = locator
+                changed = True
+            if state.mtime_ns != int(mtime_ns):
+                state.mtime_ns = int(mtime_ns)
+                changed = True
+            if changed:
+                await session.flush()
                out["state_updated"] = True
        else:
            logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash)
@@ -278,11 +274,10 @@ async def touch_asset_infos_by_fs_path(
    stmt = sa.update(AssetInfo).where(
        sa.exists(
            sa.select(sa.literal(1))
-            .select_from(Asset)
+            .select_from(AssetCacheState)
            .where(
-                Asset.hash == AssetInfo.asset_hash,
-                Asset.storage_backend == "fs",
-                Asset.storage_locator == locator,
+                AssetCacheState.asset_hash == AssetInfo.asset_hash,
+                AssetCacheState.file_path == locator,
            )
        )
    )
@@ -337,13 +332,6 @@ async def list_asset_infos_page(
    We purposely collect tags in a separate (single) query to avoid row explosion.
    """
-    # Clamp
-    if limit <= 0:
-        limit = 1
-    if limit > 100:
-        limit = 100
-    if offset < 0:
-        offset = 0

    # Build base query
    base = (
@@ -419,6 +407,66 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int
    return pair[0], pair[1]


+async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
+    return await session.get(AssetCacheState, asset_hash)
+
+
+async def list_asset_locations(
+    session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
+) -> list[AssetLocation]:
+    stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
+    if provider:
+        stmt = stmt.where(AssetLocation.provider == provider)
+    return (await session.execute(stmt)).scalars().all()
+
+
+async def upsert_asset_location(
+    session: AsyncSession,
+    *,
+    asset_hash: str,
+    provider: str,
+    locator: str,
+    expected_size_bytes: Optional[int] = None,
+    etag: Optional[str] = None,
+    last_modified: Optional[str] = None,
+) -> AssetLocation:
+    loc = (
+        await session.execute(
+            select(AssetLocation).where(
+                AssetLocation.asset_hash == asset_hash,
+                AssetLocation.provider == provider,
+                AssetLocation.locator == locator,
+            ).limit(1)
+        )
+    ).scalars().first()
+    if loc:
+        changed = False
+        if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes:
+            loc.expected_size_bytes = expected_size_bytes
+            changed = True
+        if etag is not None and loc.etag != etag:
+            loc.etag = etag
+            changed = True
+        if last_modified is not None and loc.last_modified != last_modified:
+            loc.last_modified = last_modified
+            changed = True
+        if changed:
+            await session.flush()
+        return loc
+
+    loc = AssetLocation(
+        asset_hash=asset_hash,
+        provider=provider,
+        locator=locator,
+        expected_size_bytes=expected_size_bytes,
+        etag=etag,
+        last_modified=last_modified,
+    )
+    session.add(loc)
+    await session.flush()
+    return loc


async def create_asset_info_for_existing_asset(
    session: AsyncSession,
    *,
@@ -925,7 +973,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
        rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
    elif isinstance(value, (int, float, Decimal)):
        # store numeric; SQLAlchemy will coerce to Numeric
-        rows.append({"key": key, "ordinal": 0, "val_num": value})
+        num = value if isinstance(value, Decimal) else Decimal(str(value))
+        rows.append({"key": key, "ordinal": 0, "val_num": num})
    elif isinstance(value, str):
        rows.append({"key": key, "ordinal": 0, "val_str": value})
    else:
@@ -943,7 +992,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
            elif isinstance(x, bool):
                rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
            elif isinstance(x, (int, float, Decimal)):
-                rows.append({"key": key, "ordinal": i, "val_num": x})
+                num = x if isinstance(x, Decimal) else Decimal(str(x))
+                rows.append({"key": key, "ordinal": i, "val_num": num})
            elif isinstance(x, str):
                rows.append({"key": key, "ordinal": i, "val_str": x})
            else:
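
For orientation, a sketch of how the new location helpers might be used to register and look up a remote copy of an asset (not part of this commit; the bucket path, size, and hash are placeholders, and the import paths follow the app package layout shown in this diff):

from app.database.db import create_session
from app.database.services import upsert_asset_location, list_asset_locations

async def register_remote_copy() -> None:
    async with await create_session() as session:
        # Idempotent: re-running with the same (hash, provider, locator) only updates metadata.
        await upsert_asset_location(
            session,
            asset_hash="blake3:0123abcd...",  # placeholder identity
            provider="gcs",
            locator="gs://example-bucket/models/vae/example.safetensors",
            expected_size_bytes=334_641_162,  # illustrative value
        )
        await session.commit()

        # A resolver backend can later enumerate the known copies per provider.
        locations = await list_asset_locations(session, asset_hash="blake3:0123abcd...", provider="gcs")
        print([loc.locator for loc in locations])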

app/resolvers/__init__.py (new file, 35 lines added)
View File

@@ -0,0 +1,35 @@
import contextlib
from dataclasses import dataclass
from typing import Protocol, Optional, Mapping


@dataclass
class ResolveResult:
    provider: str  # e.g., "gcs"
    download_url: str  # fully-qualified URL to fetch bytes
    headers: Mapping[str, str]  # optional auth headers etc
    expected_size: Optional[int] = None
    tags: Optional[list[str]] = None  # e.g. ["models","vae","subdir"]
    filename: Optional[str] = None  # preferred basename


class AssetResolver(Protocol):
    provider: str

    async def resolve(self, asset_hash: str) -> Optional[ResolveResult]: ...


_REGISTRY: list[AssetResolver] = []


def register_resolver(resolver: AssetResolver) -> None:
    """Append Resolver with simple de-dup per provider."""
    global _REGISTRY
    _REGISTRY = [r for r in _REGISTRY if r.provider != resolver.provider] + [resolver]


async def resolve_asset(asset_hash: str) -> Optional[ResolveResult]:
    for r in _REGISTRY:
        with contextlib.suppress(Exception):  # For Resolver failure we just try the next one
            res = await r.resolve(asset_hash)
            if res:
                return res
    return None
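
For orientation, a hypothetical resolver built against the Protocol above (nothing below ships in this commit; the class name, URL map, and placement tags are illustrative):

from typing import Optional

from app.resolvers import ResolveResult, register_resolver


class StaticHttpResolver:
    """Maps known asset hashes to plain HTTPS URLs (illustrative only)."""

    provider = "http"

    def __init__(self, url_map: dict[str, str]):
        self._url_map = url_map  # asset_hash -> download URL

    async def resolve(self, asset_hash: str) -> Optional[ResolveResult]:
        url = self._url_map.get(asset_hash)
        if not url:
            return None  # resolve_asset() then falls through to the next registered resolver
        return ResolveResult(
            provider=self.provider,
            download_url=url,
            headers={},
            expected_size=None,
            tags=["models", "checkpoints"],  # placement hint consumed by ensure_asset_cached
            filename=url.rsplit("/", 1)[-1],
        )


# Registering replaces any previously registered resolver with the same provider name.
register_resolver(StaticHttpResolver({
    "blake3:0123abcd...": "https://example.com/files/example.safetensors",
}))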