add AssetsResolver support

bigcat88 2025-08-26 14:19:56 +03:00
parent a763cbd39d
commit 6fade5da38
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
9 changed files with 371 additions and 106 deletions


@ -1,3 +1,4 @@
# File: /alembic_db/versions/0001_assets.py
"""initial assets schema + per-asset state cache
Revision ID: 0001_assets
@ -22,15 +23,12 @@ def upgrade() -> None:
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
sa.Column("storage_locator", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])
# ASSETS_INFO: user-visible references (mutable metadata)
op.create_table(
@ -52,11 +50,12 @@ def upgrade() -> None:
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=128), primary_key=True),
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
@ -65,8 +64,8 @@ def upgrade() -> None:
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_by", sa.String(length=128), nullable=True),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
@ -75,15 +74,15 @@ def upgrade() -> None:
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records
# ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset
op.create_table(
"asset_locator_state",
"asset_cache_state",
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("etag", sa.String(length=256), nullable=True),
sa.Column("last_modified", sa.String(length=128), nullable=True),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
@ -102,6 +101,21 @@ def upgrade() -> None:
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
# ASSET_LOCATIONS: remote locations per asset
op.create_table(
"asset_locations",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
sa.Column("provider", sa.String(length=32), nullable=False), # e.g., "gcs"
sa.Column("locator", sa.Text(), nullable=False), # e.g., "gs://bucket/path/to/blob"
sa.Column("expected_size_bytes", sa.BigInteger(), nullable=True),
sa.Column("etag", sa.String(length=256), nullable=True),
sa.Column("last_modified", sa.String(length=128), nullable=True),
sa.UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
)
op.create_index("ix_asset_locations_hash", "asset_locations", ["asset_hash"])
op.create_index("ix_asset_locations_provider", "asset_locations", ["provider"])
# Tags vocabulary for models
tags_table = sa.table(
"tags",
@ -143,13 +157,18 @@ def upgrade() -> None:
def downgrade() -> None:
op.drop_index("ix_asset_locations_provider", table_name="asset_locations")
op.drop_index("ix_asset_locations_hash", table_name="asset_locations")
op.drop_table("asset_locations")
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_table("asset_locator_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
@ -159,6 +178,7 @@ def downgrade() -> None:
op.drop_table("tags")
op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
@ -166,6 +186,5 @@ def downgrade() -> None:
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
op.drop_index("ix_assets_backend_locator", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")
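For reference, a minimal sketch of how a row in the new asset_locations table could be seeded from a later data migration, reusing the sa.table() pattern this revision already uses for the tags vocabulary; the hash and gs:// path are placeholders, not values from this commit:

import sqlalchemy as sa
from alembic import op

asset_locations_table = sa.table(
    "asset_locations",
    sa.column("asset_hash", sa.String),
    sa.column("provider", sa.String),
    sa.column("locator", sa.Text),
    sa.column("expected_size_bytes", sa.BigInteger),
)

def seed_example_location() -> None:
    # Hypothetical helper for a follow-up data migration; values are placeholders
    # and the asset_hash must already exist in the assets table (FK constraint).
    op.bulk_insert(
        asset_locations_table,
        [{
            "asset_hash": "blake3:<hex>",
            "provider": "gcs",
            "locator": "gs://example-bucket/path/to/blob",
            "expected_size_bytes": 1234,
        }],
    )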


@ -105,7 +105,10 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
bases = folder_paths.folder_names_and_paths[tags[1]][0]
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
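With the KeyError guard above, an unrecognized category now surfaces as a ValueError with a clear message instead of an unhandled KeyError. A small usage sketch, assuming "vae" is a configured folder_paths category; the import path and tag values are illustrative:

from app._assets_helpers import resolve_destination_from_tags  # import path assumed from this diff

# Known category: returns the first configured base dir plus the remaining
# tags, which are assumed to become subdirectories under that base.
base_dir, subdirs = resolve_destination_from_tags(["models", "vae", "flux"])

# Unknown category: now a ValueError instead of a KeyError.
try:
    resolve_destination_from_tags(["models", "not-a-category"])
except ValueError as exc:
    print(exc)  # unknown model category 'not-a-category'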


@ -9,7 +9,7 @@ from pydantic import ValidationError
import folder_paths
from .. import assets_manager, assets_scanner
from . import schemas_in
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
@ -20,6 +20,9 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = await assets_manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
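The added check only accepts identifiers of the form blake3:<lowercase hex>. An equivalent standalone sketch of the rule (the helper name is illustrative, not part of this commit):

def is_valid_asset_hash(hash_str: str) -> bool:
    # Mirrors the route check; the route lowercases the value first,
    # so only lowercase hex digests reach this comparison.
    algo, _, digest = hash_str.partition(":")
    return algo == "blake3" and bool(digest) and all(c in "0123456789abcdef" for c in digest)

assert is_valid_asset_hash("blake3:9f2a0c")
assert not is_valid_asset_hash("sha256:9f2a0c")
assert not is_valid_asset_hash("blake3:")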
@ -69,7 +72,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = filename.replace('"', "'")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
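Stripping CR and LF before building the header keeps a client-supplied filename from injecting additional response headers; the RFC 5987 filename* parameter is still percent-encoded separately. A standalone sketch of the same sanitization (function name is illustrative):

import urllib.parse

def build_content_disposition(disposition: str, filename: str) -> str:
    # Drop CR/LF to prevent header injection and avoid raw double quotes inside
    # the quoted filename parameter; urllib.parse.quote percent-encodes the UTF-8
    # variant, including any remaining control characters.
    quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
    return f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'

print(build_content_disposition("attachment", 'report"v2\r\n.png'))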
@ -115,6 +118,7 @@ async def upload_asset(request: web.Request) -> web.Response:
user_metadata_raw: Optional[str] = None
file_written = 0
tmp_path: Optional[str] = None
while True:
field = await reader.next()
if field is None:
@ -173,6 +177,8 @@ async def upload_asset(request: web.Request) -> web.Response:
return _validation_error_response("INVALID_BODY", ve)
if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'")
try:
@ -182,12 +188,14 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=file_client_name,
)
return web.json_response(created.model_dump(mode="json"), status=201)
except ValueError:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(500, "INTERNAL", "Unexpected server error.")
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ROUTES.put("/api/assets/{id}")
@ -341,6 +349,7 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
states = assets_scanner.current_statuses()
if root in {"models", "input", "output"}:
states = [s for s in states.scans if s.root == root] # type: ignore
states = schemas_out.AssetScanStatusResponse(scans=states)
return web.json_response(states.model_dump(mode="json"), status=200)

app/assets_fetcher.py (new file, 132 lines)

@ -0,0 +1,132 @@
from __future__ import annotations
import asyncio
import os
import tempfile
from typing import Optional
import mimetypes
import aiohttp
from .storage.hashing import blake3_hash_sync
from .database.db import create_session
from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash
from .resolvers import resolve_asset
from ._assets_helpers import resolve_destination_from_tags, ensure_within_base
_FETCH_LOCKS: dict[str, asyncio.Lock] = {}
def _sanitize_filename(name: str) -> str:
return os.path.basename((name or "").strip()) or "file"
async def ensure_asset_cached(
asset_hash: str,
*,
preferred_name: Optional[str] = None,
tags_hint: Optional[list[str]] = None,
) -> str:
"""
Ensure there is a verified local file for `asset_hash` in the correct Comfy folder.
Policy:
- Resolver must provide valid tags (root and, for models, category).
- If target path already exists:
* if hash matches -> reuse & ingest
* else -> remove and overwrite with the correct content
"""
lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
async with lock:
# 1) If we already have a state -> trust the path
async with await create_session() as sess:
state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash)
if state and os.path.isfile(state.file_path):
return state.file_path
# 2) Resolve remote location + placement hints (must include valid tags)
res = await resolve_asset(asset_hash)
if not res:
raise FileNotFoundError(f"No resolver/locations for {asset_hash}")
placement_tags = tags_hint or res.tags
if not placement_tags:
raise ValueError(f"Resolver did not provide placement tags for {asset_hash}")
name_hint = res.filename or preferred_name or asset_hash.replace(":", "_")
safe_name = _sanitize_filename(name_hint)
# 3) Map tags -> destination (strict: raises if invalid root or models category)
base_dir, subdirs = resolve_destination_from_tags(placement_tags) # may raise
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
final_path = os.path.abspath(os.path.join(dest_dir, safe_name))
ensure_within_base(final_path, base_dir)
# 4) If target path exists, try to reuse; else delete invalid cache
if os.path.exists(final_path) and os.path.isfile(final_path):
existing_digest = blake3_hash_sync(final_path)
if f"blake3:{existing_digest}" == asset_hash:
size_bytes = os.path.getsize(final_path)
mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
async with await create_session() as sess:
await ingest_fs_asset(
sess,
asset_hash=asset_hash,
abs_path=final_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=None,
tags=(),
)
await sess.commit()
return final_path
else:
# Invalid cache: remove before re-downloading
os.remove(final_path)
# 5) Download to temp next to destination
timeout = aiohttp.ClientTimeout(total=60 * 30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(res.download_url, headers=dict(res.headers)) as resp:
resp.raise_for_status()
cl = resp.headers.get("Content-Length")
if res.expected_size and cl and int(cl) != int(res.expected_size):
raise ValueError("server Content-Length does not match expected size")
with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp:
tmp_path = tmp.name
async for chunk in resp.content.iter_chunked(8 * 1024 * 1024):
if chunk:
tmp.write(chunk)
# 6) Verify content hash
digest = blake3_hash_sync(tmp_path)
canonical = f"blake3:{digest}"
if canonical != asset_hash:
try:
os.remove(tmp_path)
finally:
raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")
# 7) Atomically move into place (we already removed an invalid file if it existed)
if os.path.exists(final_path):
os.remove(final_path)
os.replace(tmp_path, final_path)
# 8) Record identity + cache state (+ mime type)
size_bytes = os.path.getsize(final_path)
mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
mime_type = mimetypes.guess_type(safe_name, strict=False)[0]
async with await create_session() as sess:
await ingest_fs_asset(
sess,
asset_hash=asset_hash,
abs_path=final_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=mime_type,
info_name=None,
tags=(),
)
await sess.commit()
return final_path
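A minimal caller sketch, assuming the hash is known to a registered resolver (or has rows in asset_locations) and that import paths match this diff's layout; the digest and names are placeholders:

import asyncio
from app.assets_fetcher import ensure_asset_cached  # import path assumed

async def main() -> None:
    # tags_hint takes precedence over resolver-provided tags when deciding placement.
    path = await ensure_asset_cached(
        "blake3:<hex>",                        # placeholder digest
        preferred_name="example.safetensors",  # used only if the resolver has no filename
        tags_hint=["models", "checkpoints"],
    )
    print("cached at", path)

asyncio.run(main())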


@ -26,7 +26,8 @@ from .database.services import (
create_asset_info_for_existing_asset,
)
from .api import schemas_in, schemas_out
from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base
from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
from .assets_fetcher import ensure_asset_cached
async def asset_exists(*, asset_hash: str) -> bool:
@ -46,17 +47,17 @@ def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) ->
file_name=asset_name,
file_path=file_path,
)
except ValueError:
logging.exception("Cant parse '%s' as an asset file path.", file_path)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
Notes:
- Uses absolute path as the canonical locator for the 'fs' backend.
- Uses absolute path as the canonical locator for the cache backend.
- Computes BLAKE3 only when the fast existence check indicates it's needed.
- This function ensures the identity row and seeds mtime in asset_locator_state.
- This function ensures the identity row and seeds mtime in asset_cache_state.
"""
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
@ -125,7 +126,7 @@ async def list_assets(
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/v1/assets/{info.id}/content",
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
@ -143,12 +144,11 @@ async def resolve_asset_content_for_download(
*, asset_info_id: int
) -> tuple[str, str, str]:
"""
Returns (abs_path, content_type, download_name) for the given AssetInfo id.
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
Also touches last_access_time (only_if_newer).
Ensures the local cache is present (uses resolver if needed).
Raises:
ValueError if AssetInfo not found
NotImplementedError for unsupported backend
FileNotFoundError if underlying file does not exist (fs backend)
ValueError if AssetInfo cannot be found
"""
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
@ -156,21 +156,19 @@ async def resolve_asset_content_for_download(
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
tag_names = await get_asset_tags(session, asset_info_id=info.id)
if asset.storage_backend != "fs":
# Future: support http/s3/gcs/...
raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet")
abs_path = os.path.abspath(asset.storage_locator)
if not os.path.exists(abs_path):
raise FileNotFoundError(abs_path)
# Ensure cached (download if missing)
preferred_name = info.name or info.asset_hash.split(":", 1)[-1]
abs_path = await ensure_asset_cached(info.asset_hash, preferred_name=preferred_name, tags_hint=tag_names)
async with await create_session() as session:
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
async def upload_asset_from_temp_path(
@ -238,7 +236,7 @@ async def upload_asset_from_temp_path(
added_by=None,
require_existing_tags=False,
)
info_id = result.get("asset_info_id")
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
@ -260,7 +258,7 @@ async def upload_asset_from_temp_path(
preview_hash=info.preview_hash,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=True,
created_new=result["asset_created"],
)
@ -416,7 +414,7 @@ def _get_size_mtime_ns(path: str) -> tuple[int, int]:
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: Optional[str] , fallback: str) -> str:
def _safe_filename(name: Optional[str], fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n


@ -147,9 +147,7 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
prog.last_error = str(exc)
finally:
prog.finished_at = time.time()
t = RUNNING_TASKS.get(root)
if t and t.done():
RUNNING_TASKS.pop(root, None)
RUNNING_TASKS.pop(root, None)
async def _scan_models(prog: ScanProgress) -> None:


@ -45,8 +45,6 @@ class Asset(Base):
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
@ -71,48 +69,71 @@ class Asset(Base):
viewonly=True,
)
locator_state: Mapped["AssetLocatorState | None"] = relationship(
cache_state: Mapped["AssetCacheState | None"] = relationship(
back_populates="asset",
uselist=False,
cascade="all, delete-orphan",
passive_deletes=True,
)
locations: Mapped[list["AssetLocation"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("ix_assets_mime_type", "mime_type"),
Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
return f"<Asset hash={self.hash[:12]}>"
class AssetLocatorState(Base):
__tablename__ = "asset_locator_state"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
)
# For fs backends: nanosecond mtime; nullable if not applicable
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
# For HTTP/S3/GCS/Azure, etc.: optional validators
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)
__table_args__ = (
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
Index("ix_asset_cache_state_file_path", "file_path"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
return f"<AssetCacheState hash={self.asset_hash[:12]} path={self.file_path!r}>"
class AssetLocation(Base):
__tablename__ = "asset_locations"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs"
locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object"
expected_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)
asset: Mapped["Asset"] = relationship(back_populates="locations")
__table_args__ = (
UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
Index("ix_asset_locations_hash", "asset_hash"),
Index("ix_asset_locations_provider", "provider"),
)
class AssetInfo(Base):
@ -220,7 +241,7 @@ class AssetInfoTag(Base):
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_by: Mapped[str | None] = mapped_column(String(128))
@ -240,7 +261,7 @@ class AssetInfoTag(Base):
class Tag(Base):
__tablename__ = "tags"
name: Mapped[str] = mapped_column(String(128), primary_key=True)
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(


@ -12,7 +12,7 @@ from sqlalchemy import select, delete, exists, func
from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
from .timeutil import utcnow
from .._assets_helpers import normalize_tags
@ -38,30 +38,24 @@ async def check_fs_asset_exists_quick(
mtime_ns: Optional[int] = None,
) -> bool:
"""
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
Returns True if there is already an AssetCacheState record that matches this absolute path,
AND (if provided) mtime_ns matches the stored cache state,
AND (if provided) size_bytes matches verified size when known.
"""
locator = os.path.abspath(file_path)
stmt = select(sa.literal(True)).select_from(Asset)
stmt = select(sa.literal(True)).select_from(AssetCacheState).join(
Asset, Asset.hash == AssetCacheState.asset_hash
).where(AssetCacheState.file_path == locator).limit(1)
conditions = [
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
]
# If size_bytes provided require equality when the asset has a verified (non-zero) size.
# If verified size is 0 (unknown), we don't force equality.
if size_bytes is not None:
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
# If mtime_ns provided require the locator-state to exist and match.
conds = []
if mtime_ns is not None:
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
stmt = stmt.where(*conditions).limit(1)
if conds:
stmt = stmt.where(*conds)
row = (await session.execute(stmt)).first()
return row is not None
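A short sketch of the intended call pattern during a scan; the keyword-only signature and import path are assumed from the hunk above:

import os
from app.database.services import check_fs_asset_exists_quick  # import path assumed

async def already_ingested(session, file_path: str) -> bool:
    # Passing size and mtime lets the quick check skip re-hashing unchanged files.
    st = os.stat(file_path)
    mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
    return await check_fs_asset_exists_quick(
        session, file_path=file_path, size_bytes=st.st_size, mtime_ns=mtime_ns
    )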
@ -85,11 +79,11 @@ async def ingest_fs_asset(
require_existing_tags: bool = False,
) -> dict:
"""
Creates or updates Asset record for a local (fs) asset.
Upsert Asset identity row + cache state pointing at local file.
Always:
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
- Insert AssetLocatorState if missing; else update mtime_ns if different.
- Insert AssetCacheState if missing; else update mtime_ns if different.
Optionally (when info_name is provided):
- Create an AssetInfo (no refcount changes).
@ -126,8 +120,6 @@ async def ingest_fs_asset(
size_bytes=int(size_bytes),
mime_type=mime_type,
refcount=0,
storage_backend="fs",
storage_locator=locator,
created_at=datetime_now,
updated_at=datetime_now,
)
@ -145,21 +137,19 @@ async def ingest_fs_asset(
if mime_type and existing.mime_type != mime_type:
existing.mime_type = mime_type
changed = True
if existing.storage_locator != locator:
existing.storage_locator = locator
changed = True
if changed:
existing.updated_at = datetime_now
out["asset_updated"] = True
else:
logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash)
# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
# ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ----
with contextlib.suppress(IntegrityError):
async with session.begin_nested():
session.add(
AssetLocatorState(
AssetCacheState(
asset_hash=asset_hash,
file_path=locator,
mtime_ns=int(mtime_ns),
)
)
@ -167,11 +157,17 @@ async def ingest_fs_asset(
out["state_created"] = True
if not out["state_created"]:
state = await session.get(AssetLocatorState, asset_hash)
state = await session.get(AssetCacheState, asset_hash)
if state is not None:
desired_mtime = int(mtime_ns)
if state.mtime_ns != desired_mtime:
state.mtime_ns = desired_mtime
changed = False
if state.file_path != locator:
state.file_path = locator
changed = True
if state.mtime_ns != int(mtime_ns):
state.mtime_ns = int(mtime_ns)
changed = True
if changed:
await session.flush()
out["state_updated"] = True
else:
logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash)
@ -278,11 +274,10 @@ async def touch_asset_infos_by_fs_path(
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(Asset)
.select_from(AssetCacheState)
.where(
Asset.hash == AssetInfo.asset_hash,
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
AssetCacheState.asset_hash == AssetInfo.asset_hash,
AssetCacheState.file_path == locator,
)
)
)
@ -337,13 +332,6 @@ async def list_asset_infos_page(
We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Clamp
if limit <= 0:
limit = 1
if limit > 100:
limit = 100
if offset < 0:
offset = 0
# Build base query
base = (
@ -419,6 +407,66 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in
return pair[0], pair[1]
async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
return await session.get(AssetCacheState, asset_hash)
async def list_asset_locations(
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
) -> list[AssetLocation]:
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
if provider:
stmt = stmt.where(AssetLocation.provider == provider)
return list((await session.execute(stmt)).scalars().all())
async def upsert_asset_location(
session: AsyncSession,
*,
asset_hash: str,
provider: str,
locator: str,
expected_size_bytes: Optional[int] = None,
etag: Optional[str] = None,
last_modified: Optional[str] = None,
) -> AssetLocation:
loc = (
await session.execute(
select(AssetLocation).where(
AssetLocation.asset_hash == asset_hash,
AssetLocation.provider == provider,
AssetLocation.locator == locator,
).limit(1)
)
).scalars().first()
if loc:
changed = False
if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes:
loc.expected_size_bytes = expected_size_bytes
changed = True
if etag is not None and loc.etag != etag:
loc.etag = etag
changed = True
if last_modified is not None and loc.last_modified != last_modified:
loc.last_modified = last_modified
changed = True
if changed:
await session.flush()
return loc
loc = AssetLocation(
asset_hash=asset_hash,
provider=provider,
locator=locator,
expected_size_bytes=expected_size_bytes,
etag=etag,
last_modified=last_modified,
)
session.add(loc)
await session.flush()
return loc
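A usage sketch of the upsert helper as a resolver backend might call it after discovering a remote copy; provider, locator, and validators are placeholders:

async def record_remote_copy(session, asset_hash: str) -> None:
    # Idempotent: re-running with the same (hash, provider, locator) triplet only
    # refreshes the optional validators instead of inserting a duplicate row.
    await upsert_asset_location(
        session,
        asset_hash=asset_hash,
        provider="gcs",
        locator="gs://example-bucket/path/to/blob",
        expected_size_bytes=1234,
        etag='"abc123"',
        last_modified="Tue, 26 Aug 2025 11:00:00 GMT",
    )
    await session.commit()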
async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
@ -925,7 +973,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
# store numeric; SQLAlchemy will coerce to Numeric
rows.append({"key": key, "ordinal": 0, "val_num": value})
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
@ -943,7 +992,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
rows.append({"key": key, "ordinal": i, "val_num": x})
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:

app/resolvers/__init__.py (new file, 35 lines)

@ -0,0 +1,35 @@
import contextlib
from dataclasses import dataclass
from typing import Protocol, Optional, Mapping
@dataclass
class ResolveResult:
provider: str # e.g., "gcs"
download_url: str # fully-qualified URL to fetch bytes
headers: Mapping[str, str] # optional auth headers etc
expected_size: Optional[int] = None
tags: Optional[list[str]] = None # e.g. ["models","vae","subdir"]
filename: Optional[str] = None # preferred basename
class AssetResolver(Protocol):
provider: str
async def resolve(self, asset_hash: str) -> Optional[ResolveResult]: ...
_REGISTRY: list[AssetResolver] = []
def register_resolver(resolver: AssetResolver) -> None:
"""Append Resolver with simple de-dup per provider."""
global _REGISTRY
_REGISTRY = [r for r in _REGISTRY if r.provider != resolver.provider] + [resolver]
async def resolve_asset(asset_hash: str) -> Optional[ResolveResult]:
for r in _REGISTRY:
with contextlib.suppress(Exception):  # on resolver failure, fall through to the next registered resolver
res = await r.resolve(asset_hash)
if res:
return res
return None
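A minimal resolver sketch, assuming a hypothetical external index service that maps a BLAKE3 hash to a signed download URL; the class, endpoint, and response fields are illustrative, not part of this commit:

from typing import Optional

import aiohttp

from app.resolvers import ResolveResult, register_resolver  # import path assumed

class ExampleIndexResolver:
    provider = "example-index"

    async def resolve(self, asset_hash: str) -> Optional[ResolveResult]:
        # Hypothetical lookup endpoint; returning None lets resolve_asset()
        # fall through to the next registered resolver.
        async with aiohttp.ClientSession() as session:
            async with session.get(f"https://index.example/v1/assets/{asset_hash}") as resp:
                if resp.status != 200:
                    return None
                data = await resp.json()
        return ResolveResult(
            provider=self.provider,
            download_url=data["url"],
            headers={},
            expected_size=data.get("size"),
            tags=data.get("tags"),        # e.g. ["models", "vae"]
            filename=data.get("filename"),
        )

register_resolver(ExampleIndexResolver())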