add AssetsResolver support

parent a763cbd39d
commit 6fade5da38
@@ -1,3 +1,4 @@
# File: /alembic_db/versions/0001_assets.py
"""initial assets schema + per-asset state cache

Revision ID: 0001_assets
@@ -22,15 +23,12 @@ def upgrade() -> None:
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("mime_type", sa.String(length=255), nullable=True),
sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
sa.Column("storage_locator", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
)
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
op.create_index("ix_assets_backend_locator", "assets", ["storage_backend", "storage_locator"])

# ASSETS_INFO: user-visible references (mutable metadata)
op.create_table(
@@ -52,11 +50,12 @@ def upgrade() -> None:
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])

# TAGS: normalized tag vocabulary
op.create_table(
"tags",
sa.Column("name", sa.String(length=128), primary_key=True),
sa.Column("name", sa.String(length=512), primary_key=True),
sa.Column("tag_type", sa.String(length=32), nullable=False, server_default="user"),
sa.CheckConstraint("name = lower(name)", name="ck_tags_lowercase"),
)
@@ -65,8 +64,8 @@ def upgrade() -> None:
# ASSET_INFO_TAGS: many-to-many for tags on AssetInfo
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("asset_info_id", sa.Integer(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_by", sa.String(length=128), nullable=True),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
@@ -75,15 +74,15 @@ def upgrade() -> None:
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

# ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records
# ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset
op.create_table(
"asset_locator_state",
"asset_cache_state",
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("etag", sa.String(length=256), nullable=True),
sa.Column("last_modified", sa.String(length=128), nullable=True),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])

# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
@@ -102,6 +101,21 @@ def upgrade() -> None:
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

# ASSET_LOCATIONS: remote locations per asset
op.create_table(
"asset_locations",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
sa.Column("provider", sa.String(length=32), nullable=False), # e.g., "gcs"
sa.Column("locator", sa.Text(), nullable=False), # e.g., "gs://bucket/path/to/blob"
sa.Column("expected_size_bytes", sa.BigInteger(), nullable=True),
sa.Column("etag", sa.String(length=256), nullable=True),
sa.Column("last_modified", sa.String(length=128), nullable=True),
sa.UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
)
op.create_index("ix_asset_locations_hash", "asset_locations", ["asset_hash"])
op.create_index("ix_asset_locations_provider", "asset_locations", ["provider"])

# Tags vocabulary for models
tags_table = sa.table(
"tags",
@@ -143,13 +157,18 @@ def upgrade() -> None:


def downgrade() -> None:
op.drop_index("ix_asset_locations_provider", table_name="asset_locations")
op.drop_index("ix_asset_locations_hash", table_name="asset_locations")
op.drop_table("asset_locations")

op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")

op.drop_table("asset_locator_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")

op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
@@ -159,6 +178,7 @@ def downgrade() -> None:
op.drop_table("tags")

op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
@@ -166,6 +186,5 @@ def downgrade() -> None:
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")

op.drop_index("ix_assets_backend_locator", table_name="assets")
op.drop_index("ix_assets_mime_type", table_name="assets")
op.drop_table("assets")

@@ -105,7 +105,10 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
bases = folder_paths.folder_names_and_paths[tags[1]][0]
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
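
# Illustrative behaviour of the helper above (editor's sketch, not from this commit);
# "checkpoints" is assumed to be a category configured in folder_paths:
#   resolve_destination_from_tags(["models", "checkpoints"])   -> (base_dir, subdirs)
#   resolve_destination_from_tags(["models"])                  -> ValueError (category tag required)
#   resolve_destination_from_tags(["models", "no-such-kind"])  -> ValueError (unknown model category)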

@@ -9,7 +9,7 @@ from pydantic import ValidationError
import folder_paths

from .. import assets_manager, assets_scanner
from . import schemas_in
from . import schemas_in, schemas_out


ROUTES = web.RouteTableDef()
@@ -20,6 +20,9 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
hash_str = request.match_info.get("hash", "").strip().lower()
if not hash_str or ":" not in hash_str:
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
algo, digest = hash_str.split(":", 1)
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
exists = await assets_manager.asset_exists(asset_hash=hash_str)
return web.Response(status=200 if exists else 404)
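
# Accepted hash format for this handler (illustrative):
#   "blake3:9f86d081884c7d65..."                      -> looked up; 200 if the asset exists, 404 otherwise
#   "sha256:abc..." or "blake3:xyz" (non-hex digest)  -> 400 INVALID_HASH
# The value is lower-cased first, so "BLAKE3:..." is treated the same as "blake3:...".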

@@ -69,7 +72,7 @@ async def download_asset_content(request: web.Request) -> web.Response:
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")

quoted = filename.replace('"', "'")
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'

resp = web.FileResponse(abs_path)
@@ -115,6 +118,7 @@ async def upload_asset(request: web.Request) -> web.Response:
user_metadata_raw: Optional[str] = None
file_written = 0

tmp_path: Optional[str] = None
while True:
field = await reader.next()
if field is None:
@@ -173,6 +177,8 @@ async def upload_asset(request: web.Request) -> web.Response:
return _validation_error_response("INVALID_BODY", ve)

if spec.tags[0] == "models" and spec.tags[1] not in folder_paths.folder_names_and_paths:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(400, "INVALID_BODY", f"unknown models category '{spec.tags[1]}'")

try:
@@ -182,12 +188,14 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=file_client_name,
)
return web.json_response(created.model_dump(mode="json"), status=201)
except ValueError:
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
finally:
return _error_response(500, "INTERNAL", "Unexpected server error.")
if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path)
return _error_response(500, "INTERNAL", "Unexpected server error.")


@ROUTES.put("/api/assets/{id}")
@@ -341,6 +349,7 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
states = assets_scanner.current_statuses()
if root in {"models", "input", "output"}:
states = [s for s in states.scans if s.root == root] # type: ignore
states = schemas_out.AssetScanStatusResponse(scans=states)
return web.json_response(states.model_dump(mode="json"), status=200)


# File: /app/assets_fetcher.py  (new file)
@@ -0,0 +1,132 @@
from __future__ import annotations
import asyncio
import os
import tempfile
from typing import Optional
import mimetypes
import aiohttp

from .storage.hashing import blake3_hash_sync
from .database.db import create_session
from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash
from .resolvers import resolve_asset
from ._assets_helpers import resolve_destination_from_tags, ensure_within_base

_FETCH_LOCKS: dict[str, asyncio.Lock] = {}


def _sanitize_filename(name: str) -> str:
return os.path.basename((name or "").strip()) or "file"
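# e.g. _sanitize_filename("  sub/../model.safetensors ") -> "model.safetensors"
#      _sanitize_filename("")                             -> "file"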


async def ensure_asset_cached(
asset_hash: str,
*,
preferred_name: Optional[str] = None,
tags_hint: Optional[list[str]] = None,
) -> str:
"""
Ensure there is a verified local file for `asset_hash` in the correct Comfy folder.
Policy:
- Resolver must provide valid tags (root and, for models, category).
- If target path already exists:
* if hash matches -> reuse & ingest
* else -> remove and overwrite with the correct content
"""
lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
async with lock:
# 1) If we already have a state -> trust the path
async with await create_session() as sess:
state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash)
if state and os.path.isfile(state.file_path):
return state.file_path

# 2) Resolve remote location + placement hints (must include valid tags)
res = await resolve_asset(asset_hash)
if not res:
raise FileNotFoundError(f"No resolver/locations for {asset_hash}")

placement_tags = tags_hint or res.tags
if not placement_tags:
raise ValueError(f"Resolver did not provide placement tags for {asset_hash}")

name_hint = res.filename or preferred_name or asset_hash.replace(":", "_")
safe_name = _sanitize_filename(name_hint)

# 3) Map tags -> destination (strict: raises if invalid root or models category)
base_dir, subdirs = resolve_destination_from_tags(placement_tags) # may raise
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)

final_path = os.path.abspath(os.path.join(dest_dir, safe_name))
ensure_within_base(final_path, base_dir)

# 4) If target path exists, try to reuse; else delete invalid cache
if os.path.exists(final_path) and os.path.isfile(final_path):
existing_digest = blake3_hash_sync(final_path)
if f"blake3:{existing_digest}" == asset_hash:
size_bytes = os.path.getsize(final_path)
mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
async with await create_session() as sess:
await ingest_fs_asset(
sess,
asset_hash=asset_hash,
abs_path=final_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=None,
info_name=None,
tags=(),
)
await sess.commit()
return final_path
else:
# Invalid cache: remove before re-downloading
os.remove(final_path)

# 5) Download to temp next to destination
timeout = aiohttp.ClientTimeout(total=60 * 30)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(res.download_url, headers=dict(res.headers)) as resp:
resp.raise_for_status()
cl = resp.headers.get("Content-Length")
if res.expected_size and cl and int(cl) != int(res.expected_size):
raise ValueError("server Content-Length does not match expected size")
with tempfile.NamedTemporaryFile("wb", delete=False, dir=dest_dir) as tmp:
tmp_path = tmp.name
async for chunk in resp.content.iter_chunked(8 * 1024 * 1024):
if chunk:
tmp.write(chunk)

# 6) Verify content hash
digest = blake3_hash_sync(tmp_path)
canonical = f"blake3:{digest}"
if canonical != asset_hash:
try:
os.remove(tmp_path)
finally:
raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")

# 7) Atomically move into place (we already removed an invalid file if it existed)
if os.path.exists(final_path):
os.remove(final_path)
os.replace(tmp_path, final_path)

# 8) Record identity + cache state (+ mime type)
size_bytes = os.path.getsize(final_path)
mtime_ns = getattr(os.stat(final_path), "st_mtime_ns", int(os.path.getmtime(final_path) * 1_000_000_000))
mime_type = mimetypes.guess_type(safe_name, strict=False)[0]
async with await create_session() as sess:
await ingest_fs_asset(
sess,
asset_hash=asset_hash,
abs_path=final_path,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=mime_type,
info_name=None,
tags=(),
)
await sess.commit()

return final_path
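

# Illustrative usage of ensure_asset_cached (editor's sketch, not part of this commit).
# The hash and tags are placeholders, and the "app.assets_fetcher" import path is an
# assumption about where this new module lives.
import asyncio

from app.assets_fetcher import ensure_asset_cached


async def _demo_fetch() -> None:
    # Downloads (or reuses) the asset and returns the absolute local path.
    path = await ensure_asset_cached(
        "blake3:<hex digest>",                # placeholder content hash
        preferred_name="model.safetensors",   # used if the resolver gives no filename
        tags_hint=["models", "checkpoints"],  # "checkpoints" assumed to be a configured category
    )
    print(path)


asyncio.run(_demo_fetch())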

@@ -26,7 +26,8 @@ from .database.services import (
create_asset_info_for_existing_asset,
)
from .api import schemas_in, schemas_out
from ._assets_helpers import get_name_and_tags_from_asset_path, resolve_destination_from_tags, ensure_within_base
from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
from .assets_fetcher import ensure_asset_cached


async def asset_exists(*, asset_hash: str) -> bool:
@@ -46,17 +47,17 @@ def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) ->
file_name=asset_name,
file_path=file_path,
)
except ValueError:
logging.exception("Cant parse '%s' as an asset file path.", file_path)
except ValueError as e:
logging.warning("Skipping non-asset path %s: %s", file_path, e)


async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
"""Adds a local asset to the DB. If already present and unchanged, does nothing.

Notes:
- Uses absolute path as the canonical locator for the 'fs' backend.
- Uses absolute path as the canonical locator for the cache backend.
- Computes BLAKE3 only when the fast existence check indicates it's needed.
- This function ensures the identity row and seeds mtime in asset_locator_state.
- This function ensures the identity row and seeds mtime in asset_cache_state.
"""
abs_path = os.path.abspath(file_path)
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
@@ -125,7 +126,7 @@ async def list_assets(
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
preview_url=f"/api/v1/assets/{info.id}/content",
preview_url=f"/api/assets/{info.id}/content",
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
@@ -143,12 +144,11 @@ async def resolve_asset_content_for_download(
*, asset_info_id: int
) -> tuple[str, str, str]:
"""
Returns (abs_path, content_type, download_name) for the given AssetInfo id.
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
Also touches last_access_time (only_if_newer).
Ensures the local cache is present (uses resolver if needed).
Raises:
ValueError if AssetInfo not found
NotImplementedError for unsupported backend
FileNotFoundError if underlying file does not exist (fs backend)
ValueError if AssetInfo cannot be found
"""
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
@@ -156,21 +156,19 @@ async def resolve_asset_content_for_download(
raise ValueError(f"AssetInfo {asset_info_id} not found")

info, asset = pair
tag_names = await get_asset_tags(session, asset_info_id=info.id)

if asset.storage_backend != "fs":
# Future: support http/s3/gcs/...
raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet")

abs_path = os.path.abspath(asset.storage_locator)
if not os.path.exists(abs_path):
raise FileNotFoundError(abs_path)
# Ensure cached (download if missing)
preferred_name = info.name or info.asset_hash.split(":", 1)[-1]
abs_path = await ensure_asset_cached(info.asset_hash, preferred_name=preferred_name, tags_hint=tag_names)

async with await create_session() as session:
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()

ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name


async def upload_asset_from_temp_path(
@@ -238,7 +236,7 @@ async def upload_asset_from_temp_path(
added_by=None,
require_existing_tags=False,
)
info_id = result.get("asset_info_id")
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")

@@ -260,7 +258,7 @@ async def upload_asset_from_temp_path(
preview_hash=info.preview_hash,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=True,
created_new=result["asset_created"],
)


@@ -416,7 +414,7 @@ def _get_size_mtime_ns(path: str) -> tuple[int, int]:
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))


def _safe_filename(name: Optional[str] , fallback: str) -> str:
def _safe_filename(name: Optional[str], fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n

@@ -147,9 +147,7 @@ async def _run_scan_for_root(root: RootType, prog: ScanProgress) -> None:
prog.last_error = str(exc)
finally:
prog.finished_at = time.time()
t = RUNNING_TASKS.get(root)
if t and t.done():
RUNNING_TASKS.pop(root, None)
RUNNING_TASKS.pop(root, None)


async def _scan_models(prog: ScanProgress) -> None:

@@ -45,8 +45,6 @@ class Asset(Base):
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
refcount: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
@@ -71,48 +69,71 @@ class Asset(Base):
viewonly=True,
)

locator_state: Mapped["AssetLocatorState | None"] = relationship(
cache_state: Mapped["AssetCacheState | None"] = relationship(
back_populates="asset",
uselist=False,
cascade="all, delete-orphan",
passive_deletes=True,
)

locations: Mapped[list["AssetLocation"]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)

__table_args__ = (
Index("ix_assets_mime_type", "mime_type"),
Index("ix_assets_backend_locator", "storage_backend", "storage_locator"),
)

def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)

def __repr__(self) -> str:
return f"<Asset hash={self.hash[:12]} backend={self.storage_backend}>"
return f"<Asset hash={self.hash[:12]}>"


class AssetLocatorState(Base):
__tablename__ = "asset_locator_state"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"

asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
)
# For fs backends: nanosecond mtime; nullable if not applicable
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
# For HTTP/S3/GCS/Azure, etc.: optional validators
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)

asset: Mapped["Asset"] = relationship(back_populates="locator_state", uselist=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)

__table_args__ = (
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_als_mtime_nonneg"),
Index("ix_asset_cache_state_file_path", "file_path"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
)

def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)

def __repr__(self) -> str:
return f"<AssetLocatorState hash={self.asset_hash[:12]} mtime_ns={self.mtime_ns}>"
return f"<AssetCacheState hash={self.asset_hash[:12]} path={self.file_path!r}>"


class AssetLocation(Base):
__tablename__ = "asset_locations"

id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs"
locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object"
expected_size_bytes: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
etag: Mapped[str | None] = mapped_column(String(256), nullable=True)
last_modified: Mapped[str | None] = mapped_column(String(128), nullable=True)

asset: Mapped["Asset"] = relationship(back_populates="locations")

__table_args__ = (
UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
Index("ix_asset_locations_hash", "asset_hash"),
Index("ix_asset_locations_provider", "provider"),
)


class AssetInfo(Base):
@@ -220,7 +241,7 @@ class AssetInfoTag(Base):
Integer, ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
)
tag_name: Mapped[str] = mapped_column(
String(128), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_by: Mapped[str | None] = mapped_column(String(128))
@@ -240,7 +261,7 @@ class AssetInfoTag(Base):
class Tag(Base):
__tablename__ = "tags"

name: Mapped[str] = mapped_column(String(128), primary_key=True)
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")

asset_info_links: Mapped[list["AssetInfoTag"]] = relationship(

@@ -12,7 +12,7 @@ from sqlalchemy import select, delete, exists, func
from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError

from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
from .timeutil import utcnow
from .._assets_helpers import normalize_tags

@@ -38,30 +38,24 @@ async def check_fs_asset_exists_quick(
mtime_ns: Optional[int] = None,
) -> bool:
"""
Returns 'True' if there is already an Asset present whose canonical locator matches this absolute path,
Returns 'True' if there is already an AssetCacheState record that matches this absolute path,
AND (if provided) mtime_ns matches stored locator-state,
AND (if provided) size_bytes matches verified size when known.
"""
locator = os.path.abspath(file_path)

stmt = select(sa.literal(True)).select_from(Asset)
stmt = select(sa.literal(True)).select_from(AssetCacheState).join(
Asset, Asset.hash == AssetCacheState.asset_hash
).where(AssetCacheState.file_path == locator).limit(1)

conditions = [
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
]

# If size_bytes provided require equality when the asset has a verified (non-zero) size.
# If verified size is 0 (unknown), we don't force equality.
if size_bytes is not None:
conditions.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))

# If mtime_ns provided require the locator-state to exist and match.
conds = []
if mtime_ns is not None:
stmt = stmt.join(AssetLocatorState, AssetLocatorState.asset_hash == Asset.hash)
conditions.append(AssetLocatorState.mtime_ns == int(mtime_ns))
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
if size_bytes is not None:
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))

stmt = stmt.where(*conditions).limit(1)
if conds:
stmt = stmt.where(*conds)

row = (await session.execute(stmt)).first()
return row is not None
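
# Semantics of the quick check above (illustrative; parameter names taken from the body):
#   await check_fs_asset_exists_quick(sess, file_path="/abs/model.safetensors",
#                                     size_bytes=123, mtime_ns=456)
# returns True only when a cache-state row exists for that exact absolute path,
# its stored mtime_ns equals the given value, and the asset's verified size
# matches (a stored size of 0 is treated as "unknown" and not enforced).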

@@ -85,11 +79,11 @@ async def ingest_fs_asset(
require_existing_tags: bool = False,
) -> dict:
"""
Creates or updates Asset record for a local (fs) asset.
Upsert Asset identity row + cache state pointing at local file.

Always:
- Insert Asset if missing; else update size_bytes (and updated_at) if different.
- Insert AssetLocatorState if missing; else update mtime_ns if different.
- Insert AssetCacheState if missing; else update mtime_ns if different.

Optionally (when info_name is provided):
- Create an AssetInfo (no refcount changes).
@@ -126,8 +120,6 @@ async def ingest_fs_asset(
size_bytes=int(size_bytes),
mime_type=mime_type,
refcount=0,
storage_backend="fs",
storage_locator=locator,
created_at=datetime_now,
updated_at=datetime_now,
)
@@ -145,21 +137,19 @@ async def ingest_fs_asset(
if mime_type and existing.mime_type != mime_type:
existing.mime_type = mime_type
changed = True
if existing.storage_locator != locator:
existing.storage_locator = locator
changed = True
if changed:
existing.updated_at = datetime_now
out["asset_updated"] = True
else:
logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash)

# ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ----
# ---- Step 2: INSERT/UPDATE AssetCacheState (mtime_ns, file_path) ----
with contextlib.suppress(IntegrityError):
async with session.begin_nested():
session.add(
AssetLocatorState(
AssetCacheState(
asset_hash=asset_hash,
file_path=locator,
mtime_ns=int(mtime_ns),
)
)
@@ -167,11 +157,17 @@ async def ingest_fs_asset(
out["state_created"] = True

if not out["state_created"]:
state = await session.get(AssetLocatorState, asset_hash)
state = await session.get(AssetCacheState, asset_hash)
if state is not None:
desired_mtime = int(mtime_ns)
if state.mtime_ns != desired_mtime:
state.mtime_ns = desired_mtime
changed = False
if state.file_path != locator:
state.file_path = locator
changed = True
if state.mtime_ns != int(mtime_ns):
state.mtime_ns = int(mtime_ns)
changed = True
if changed:
await session.flush()
out["state_updated"] = True
else:
logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash)
@@ -278,11 +274,10 @@ async def touch_asset_infos_by_fs_path(
stmt = sa.update(AssetInfo).where(
sa.exists(
sa.select(sa.literal(1))
.select_from(Asset)
.select_from(AssetCacheState)
.where(
Asset.hash == AssetInfo.asset_hash,
Asset.storage_backend == "fs",
Asset.storage_locator == locator,
AssetCacheState.asset_hash == AssetInfo.asset_hash,
AssetCacheState.file_path == locator,
)
)
)
@@ -337,13 +332,6 @@ async def list_asset_infos_page(

We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Clamp
if limit <= 0:
limit = 1
if limit > 100:
limit = 100
if offset < 0:
offset = 0

# Build base query
base = (
@@ -419,6 +407,66 @@ async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: in
return pair[0], pair[1]


async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
return await session.get(AssetCacheState, asset_hash)


async def list_asset_locations(
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
) -> list[AssetLocation]:
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
if provider:
stmt = stmt.where(AssetLocation.provider == provider)
return (await session.execute(stmt)).scalars().all()


async def upsert_asset_location(
session: AsyncSession,
*,
asset_hash: str,
provider: str,
locator: str,
expected_size_bytes: Optional[int] = None,
etag: Optional[str] = None,
last_modified: Optional[str] = None,
) -> AssetLocation:
loc = (
await session.execute(
select(AssetLocation).where(
AssetLocation.asset_hash == asset_hash,
AssetLocation.provider == provider,
AssetLocation.locator == locator,
).limit(1)
)
).scalars().first()
if loc:
changed = False
if expected_size_bytes is not None and loc.expected_size_bytes != expected_size_bytes:
loc.expected_size_bytes = expected_size_bytes
changed = True
if etag is not None and loc.etag != etag:
loc.etag = etag
changed = True
if last_modified is not None and loc.last_modified != last_modified:
loc.last_modified = last_modified
changed = True
if changed:
await session.flush()
return loc

loc = AssetLocation(
asset_hash=asset_hash,
provider=provider,
locator=locator,
expected_size_bytes=expected_size_bytes,
etag=etag,
last_modified=last_modified,
)
session.add(loc)
await session.flush()
return loc
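

# Illustrative usage of upsert_asset_location (editor's sketch, not part of this commit).
# The hash is a placeholder; the GCS provider/locator mirror the examples in the migration
# above, and create_session is used the same way as elsewhere in this commit.
#
#     from .db import create_session
#
#     async with await create_session() as sess:
#         await upsert_asset_location(
#             sess,
#             asset_hash="blake3:<hex digest>",
#             provider="gcs",
#             locator="gs://bucket/path/to/blob",
#             expected_size_bytes=None,
#         )
#         await sess.commit()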


async def create_asset_info_for_existing_asset(
session: AsyncSession,
*,
@@ -925,7 +973,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
# store numeric; SQLAlchemy will coerce to Numeric
rows.append({"key": key, "ordinal": 0, "val_num": value})
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
@@ -943,7 +992,8 @@ def _project_kv(key: str, value: Any) -> list[dict]:
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
rows.append({"key": key, "ordinal": i, "val_num": x})
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:

# File: /app/resolvers/__init__.py  (new file)
@@ -0,0 +1,35 @@
import contextlib
from dataclasses import dataclass
from typing import Protocol, Optional, Mapping


@dataclass
class ResolveResult:
provider: str # e.g., "gcs"
download_url: str # fully-qualified URL to fetch bytes
headers: Mapping[str, str] # optional auth headers etc
expected_size: Optional[int] = None
tags: Optional[list[str]] = None # e.g. ["models","vae","subdir"]
filename: Optional[str] = None # preferred basename

class AssetResolver(Protocol):
provider: str
async def resolve(self, asset_hash: str) -> Optional[ResolveResult]: ...


_REGISTRY: list[AssetResolver] = []


def register_resolver(resolver: AssetResolver) -> None:
"""Append Resolver with simple de-dup per provider."""
global _REGISTRY
_REGISTRY = [r for r in _REGISTRY if r.provider != resolver.provider] + [resolver]


async def resolve_asset(asset_hash: str) -> Optional[ResolveResult]:
for r in _REGISTRY:
with contextlib.suppress(Exception): # if a resolver fails, just try the next one
res = await r.resolve(asset_hash)
if res:
return res
return None
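

# Illustrative sketch (not part of this commit): a resolver plugin that serves assets
# from a static manifest. Only ResolveResult, the AssetResolver protocol and
# register_resolver come from the module above; the manifest contents and the
# "app.resolvers" import path are assumptions.
from typing import Optional

from app.resolvers import ResolveResult, register_resolver


class StaticManifestResolver:
    provider = "static"

    def __init__(self, manifest: dict[str, dict]):
        # maps "blake3:<hex>" -> {"url": ..., "tags": [...], "filename": ..., "size": ...}
        self._manifest = manifest

    async def resolve(self, asset_hash: str) -> Optional[ResolveResult]:
        entry = self._manifest.get(asset_hash)
        if entry is None:
            return None
        return ResolveResult(
            provider=self.provider,
            download_url=entry["url"],
            headers={},
            expected_size=entry.get("size"),
            tags=entry.get("tags"),
            filename=entry.get("filename"),
        )


register_resolver(StaticManifestResolver({}))  # registered resolvers are tried in order by resolve_asset()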