global refactoring; add support for Assets without the computed hash
This commit is contained in:
parent 934377ac1e
commit bb9ed04758
@@ -16,33 +16,44 @@ depends_on = None


 def upgrade() -> None:
-    # ASSETS: content identity (deduplicated by hash)
+    # ASSETS: content identity
     op.create_table(
         "assets",
-        sa.Column("hash", sa.String(length=256), primary_key=True),
+        sa.Column("id", sa.String(length=36), primary_key=True),
+        sa.Column("hash", sa.String(length=256), nullable=True),
         sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
         sa.Column("mime_type", sa.String(length=255), nullable=True),
         sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
         sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
     )
+    if op.get_bind().dialect.name == "postgresql":
+        op.create_index(
+            "uq_assets_hash_not_null",
+            "assets",
+            ["hash"],
+            unique=True,
+            postgresql_where=sa.text("hash IS NOT NULL"),
+        )
+    else:
+        op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
     op.create_index("ix_assets_mime_type", "assets", ["mime_type"])

-    # ASSETS_INFO: user-visible references (mutable metadata)
+    # ASSETS_INFO: user-visible references
     op.create_table(
         "assets_info",
         sa.Column("id", sa.String(length=36), primary_key=True),
         sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
         sa.Column("name", sa.String(length=512), nullable=False),
-        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
-        sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
+        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
+        sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
         sa.Column("user_metadata", sa.JSON(), nullable=True),
         sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
         sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
         sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
-        sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"),
+        sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
     )
     op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
-    op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"])
+    op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
     op.create_index("ix_assets_info_name", "assets_info", ["name"])
     op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
     op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
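The nullable hash plus the dialect-specific index is what lets an Asset row exist before its content hash has been computed. A minimal sketch of the intended constraint outside of Alembic (illustrative only; table and column names follow the migration above, the standalone metadata object is an assumption):

    import sqlalchemy as sa

    metadata = sa.MetaData()
    assets = sa.Table(
        "assets", metadata,
        sa.Column("id", sa.String(36), primary_key=True),
        sa.Column("hash", sa.String(256), nullable=True),
    )
    # On PostgreSQL only non-NULL hashes must be unique; any number of
    # un-hashed rows can coexist.
    sa.Index(
        "uq_assets_hash_not_null", assets.c.hash,
        unique=True,
        postgresql_where=sa.text("hash IS NOT NULL"),
    )

On backends such as SQLite the plain unique index in the else branch behaves the same way, since NULLs are treated as distinct for uniqueness.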
@@ -69,18 +80,19 @@ def upgrade() -> None:
     op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
     op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])

-    # ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset
+    # ASSET_CACHE_STATE: N:1 local cache rows per Asset
     op.create_table(
         "asset_cache_state",
         sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
-        sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
+        sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
         sa.Column("file_path", sa.Text(), nullable=False),  # absolute local path to cached file
         sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
+        sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
         sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
         sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
     )
     op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
-    op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"])
+    op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])

     # ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
     op.create_table(
@@ -99,7 +111,7 @@ def upgrade() -> None:
     op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
     op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

-    # Tags vocabulary for models
+    # Tags vocabulary
     tags_table = sa.table(
         "tags",
         sa.column("name", sa.String(length=512)),
@@ -108,12 +120,10 @@ def upgrade() -> None:
     op.bulk_insert(
         tags_table,
         [
-            # Root folder tags
             {"name": "models", "tag_type": "system"},
             {"name": "input", "tag_type": "system"},
             {"name": "output", "tag_type": "system"},

-            # Core tags
             {"name": "configs", "tag_type": "system"},
             {"name": "checkpoints", "tag_type": "system"},
             {"name": "loras", "tag_type": "system"},
@@ -132,12 +142,11 @@ def upgrade() -> None:
             {"name": "photomaker", "tag_type": "system"},
             {"name": "classifiers", "tag_type": "system"},

-            # Extra basic tags
             {"name": "encoder", "tag_type": "system"},
             {"name": "decoder", "tag_type": "system"},

-            # Special tags
             {"name": "missing", "tag_type": "system"},
+            {"name": "rescan", "tag_type": "system"},
         ],
     )

@@ -149,8 +158,9 @@ def downgrade() -> None:
     op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
     op.drop_table("asset_info_meta")

-    op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state")
+    op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
     op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
+    op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
     op.drop_table("asset_cache_state")

     op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
@@ -160,14 +170,18 @@ def downgrade() -> None:
     op.drop_index("ix_tags_tag_type", table_name="tags")
     op.drop_table("tags")

-    op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
+    op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
     op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
     op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
     op.drop_index("ix_assets_info_created_at", table_name="assets_info")
     op.drop_index("ix_assets_info_name", table_name="assets_info")
-    op.drop_index("ix_assets_info_asset_hash", table_name="assets_info")
+    op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
     op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
     op.drop_table("assets_info")

+    if op.get_bind().dialect.name == "postgresql":
+        op.drop_index("uq_assets_hash_not_null", table_name="assets")
+    else:
+        op.drop_index("uq_assets_hash", table_name="assets")
     op.drop_index("ix_assets_mime_type", table_name="assets")
     op.drop_table("assets")
@@ -1,5 +1,5 @@
+from .assets_scanner import sync_seed_assets
 from .database.db import init_db_engine
-from .assets_scanner import start_background_assets_scan
+from .api.assets_routes import register_assets_system

-__all__ = ["init_db_engine", "start_background_assets_scan"]
+__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"]
@@ -1,12 +1,13 @@
+import contextlib
 import os
+import uuid
+from datetime import datetime, timezone
 from pathlib import Path
-from typing import Optional, Literal, Sequence
+from typing import Literal, Optional, Sequence

-import sqlalchemy as sa

 import folder_paths

-from .database.models import AssetInfo
+from .api import schemas_in


 def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
@@ -139,14 +140,6 @@ def ensure_within_base(candidate: str, base: str) -> None:
         raise ValueError("invalid destination path")


-def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
-    """Build owner visibility predicate for reads."""
-    owner_id = (owner_id or "").strip()
-    if owner_id == "":
-        return AssetInfo.owner_id == ""
-    return AssetInfo.owner_id.in_(["", owner_id])


 def compute_model_relative_filename(file_path: str) -> Optional[str]:
     """
     Return the model's path relative to the last well-known folder (the model category),
@@ -172,3 +165,61 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]:
         return None
     inside = parts[1:] if len(parts) > 1 else [parts[0]]
     return "/".join(inside)  # normalize to POSIX style for portability
+
+
+def list_tree(base_dir: str) -> list[str]:
+    out: list[str] = []
+    base_abs = os.path.abspath(base_dir)
+    if not os.path.isdir(base_abs):
+        return out
+    for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
+        for name in filenames:
+            out.append(os.path.abspath(os.path.join(dirpath, name)))
+    return out
+
+
+def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
+    if root == "models":
+        bases: list[str] = []
+        for _bucket, paths in get_comfy_models_folders():
+            bases.extend(paths)
+        return [os.path.abspath(p) for p in bases]
+    if root == "input":
+        return [os.path.abspath(folder_paths.get_input_directory())]
+    if root == "output":
+        return [os.path.abspath(folder_paths.get_output_directory())]
+    return []
+
+
+def ts_to_iso(ts: Optional[float]) -> Optional[str]:
+    if ts is None:
+        return None
+    try:
+        return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
+    except Exception:
+        return None
+
+
+def new_scan_id(root: schemas_in.RootType) -> str:
+    return f"scan-{root}-{uuid.uuid4().hex[:8]}"
+
+
+def collect_models_files() -> list[str]:
+    out: list[str] = []
+    for folder_name, bases in get_comfy_models_folders():
+        rel_files = folder_paths.get_filename_list(folder_name) or []
+        for rel_path in rel_files:
+            abs_path = folder_paths.get_full_path(folder_name, rel_path)
+            if not abs_path:
+                continue
+            abs_path = os.path.abspath(abs_path)
+            allowed = False
+            for b in bases:
+                base_abs = os.path.abspath(b)
+                with contextlib.suppress(Exception):
+                    if os.path.commonpath([abs_path, base_abs]) == base_abs:
+                        allowed = True
+                        break
+            if allowed:
+                out.append(abs_path)
+    return out
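Taken together these helpers give the scanner a uniform way to enumerate files per root. A small usage sketch, assuming the functions are imported as defined above (the root literal and the print are illustrative):

    import time

    paths: list[str] = []
    for base in prefixes_for_root("output"):    # absolute base dirs for the root
        paths.extend(list_tree(base))           # every file below, as absolute paths
    print(len(paths), "files as of", ts_to_iso(time.time()))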
@@ -1,7 +1,7 @@
 import contextlib
 import os
-import uuid
 import urllib.parse
+import uuid
 from typing import Optional

 from aiohttp import web
@@ -12,7 +12,6 @@ import folder_paths
 from .. import assets_manager, assets_scanner, user_manager
 from . import schemas_in, schemas_out

-
 ROUTES = web.RouteTableDef()
 UserManager: Optional[user_manager.UserManager] = None

@@ -272,6 +271,7 @@ async def upload_asset(request: web.Request) -> web.Response:
         temp_path=tmp_path,
         client_filename=file_client_name,
         owner_id=owner_id,
+        expected_asset_hash=spec.hash,
     )
     status = 201 if created.created_new else 200
     return web.json_response(created.model_dump(mode="json"), status=status)
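With expected_asset_hash forwarded, the manager can reject uploads whose content does not match what the client claimed. A reduced sketch of that check (the real code lives in upload_asset_from_temp_path further down; how the digest is canonicalised before comparison is not shown in this hunk, so the direct comparison here is an assumption):

    async def _verify_upload(temp_path: str, expected_asset_hash: str | None) -> str:
        asset_hash = await hashing.blake3_hash(temp_path)
        if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
            raise ValueError("HASH_MISMATCH")  # the route surfaces this as a client error
        return asset_hash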
@@ -332,6 +332,29 @@ async def update_asset(request: web.Request) -> web.Response:
     return web.json_response(result.model_dump(mode="json"), status=200)


+@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
+async def set_asset_preview(request: web.Request) -> web.Response:
+    asset_info_id = str(uuid.UUID(request.match_info["id"]))
+    try:
+        body = schemas_in.SetPreviewBody.model_validate(await request.json())
+    except ValidationError as ve:
+        return _validation_error_response("INVALID_BODY", ve)
+    except Exception:
+        return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
+
+    try:
+        result = await assets_manager.set_asset_preview(
+            asset_info_id=asset_info_id,
+            preview_asset_id=body.preview_id,
+            owner_id=UserManager.get_request_user_id(request),
+        )
+    except (PermissionError, ValueError) as ve:
+        return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
+    except Exception:
+        return _error_response(500, "INTERNAL", "Unexpected server error.")
+    return web.json_response(result.model_dump(mode="json"), status=200)
+
+
 @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
 async def delete_asset(request: web.Request) -> web.Response:
     asset_info_id = str(uuid.UUID(request.match_info["id"]))
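A hypothetical client call against the new endpoint (the IDs are placeholders; the base URL assumes a default local ComfyUI server):

    import aiohttp

    async def set_preview(info_id: str, preview_id: str | None):
        async with aiohttp.ClientSession(base_url="http://127.0.0.1:8188") as http:
            async with http.put(f"/api/assets/{info_id}/preview", json={"preview_id": preview_id}) as resp:
                return resp.status, await resp.json()  # 200 + AssetDetail, 404 if not found/owned

Sending "preview_id": null clears the preview, per SetPreviewBody defined in the schema change below.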
@@ -1,7 +1,15 @@
 import json
+import uuid
+from typing import Any, Literal, Optional

-from typing import Any, Optional, Literal
-from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    conint,
+    field_validator,
+    model_validator,
+)


 class ListAssetsQuery(BaseModel):
@@ -148,30 +156,12 @@ class TagsRemove(TagsAdd):
     pass


-class ScheduleAssetScanBody(BaseModel):
-    roots: list[Literal["models","input","output"]] = Field(default_factory=list)
+RootType = Literal["models", "input", "output"]
+ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")

-    @field_validator("roots", mode="before")
-    @classmethod
-    def _normalize_roots(cls, v):
-        if v is None:
-            return []
-        if isinstance(v, str):
-            items = [x.strip().lower() for x in v.split(",")]
-        elif isinstance(v, list):
-            items = []
-            for x in v:
-                if isinstance(x, str):
-                    items.extend([p.strip().lower() for p in x.split(",")])
-        else:
-            return []
-        out = []
-        seen = set()
-        for r in items:
-            if r in {"models","input","output"} and r not in seen:
-                out.append(r)
-                seen.add(r)
-        return out
+class ScheduleAssetScanBody(BaseModel):
+    roots: list[RootType] = Field(..., min_length=1)


 class UploadAssetSpec(BaseModel):
@@ -281,3 +271,22 @@ class UploadAssetSpec(BaseModel):
         if len(self.tags) < 2:
             raise ValueError("models uploads require a category tag as the second tag")
         return self
+
+
+class SetPreviewBody(BaseModel):
+    """Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
+    preview_id: Optional[str] = None
+
+    @field_validator("preview_id", mode="before")
+    @classmethod
+    def _norm_uuid(cls, v):
+        if v is None:
+            return None
+        s = str(v).strip()
+        if not s:
+            return None
+        try:
+            uuid.UUID(s)
+        except Exception:
+            raise ValueError("preview_id must be a UUID")
+        return s
@@ -1,12 +1,13 @@
 from datetime import datetime
 from typing import Any, Literal, Optional

 from pydantic import BaseModel, ConfigDict, Field, field_serializer


 class AssetSummary(BaseModel):
     id: str
     name: str
-    asset_hash: str
+    asset_hash: Optional[str]
     size: Optional[int] = None
     mime_type: Optional[str] = None
     tags: list[str] = Field(default_factory=list)
@@ -31,7 +32,7 @@ class AssetsList(BaseModel):
 class AssetUpdated(BaseModel):
     id: str
     name: str
-    asset_hash: str
+    asset_hash: Optional[str]
     tags: list[str] = Field(default_factory=list)
     user_metadata: dict[str, Any] = Field(default_factory=dict)
     updated_at: Optional[datetime] = None
@@ -46,12 +47,12 @@ class AssetUpdated(BaseModel):
 class AssetDetail(BaseModel):
     id: str
     name: str
-    asset_hash: str
+    asset_hash: Optional[str]
     size: Optional[int] = None
     mime_type: Optional[str] = None
     tags: list[str] = Field(default_factory=list)
     user_metadata: dict[str, Any] = Field(default_factory=dict)
-    preview_hash: Optional[str] = None
+    preview_id: Optional[str] = None
     created_at: Optional[datetime] = None
     last_access_time: Optional[datetime] = None

@@ -95,7 +96,6 @@ class TagsRemove(BaseModel):
 class AssetScanError(BaseModel):
     path: str
     message: str
-    phase: Literal["fast", "slow"]
     at: Optional[str] = Field(None, description="ISO timestamp")


@@ -108,8 +108,6 @@ class AssetScanStatus(BaseModel):
     finished_at: Optional[str] = None
     discovered: int = 0
     processed: int = 0
-    slow_queue_total: int = 0
-    slow_queue_finished: int = 0
     file_errors: list[AssetScanError] = Field(default_factory=list)

@@ -4,38 +4,39 @@ import mimetypes
 import os
 from typing import Optional, Sequence

-from comfy.cli_args import args
 from comfy_api.internal import async_to_sync

-from .database.db import create_session
-from .storage import hashing
-from .database.services import (
-    check_fs_asset_exists_quick,
-    ingest_fs_asset,
-    touch_asset_infos_by_fs_path,
-    list_asset_infos_page,
-    update_asset_info_full,
-    get_asset_tags,
-    list_tags_with_usage,
-    add_tags_to_asset_info,
-    remove_tags_from_asset_info,
-    fetch_asset_info_and_asset,
-    touch_asset_info_by_id,
-    delete_asset_info_by_id,
-    asset_exists_by_hash,
-    get_asset_by_hash,
-    create_asset_info_for_existing_asset,
-    fetch_asset_info_asset_and_tags,
-    get_asset_info_by_id,
-    list_cache_states_by_asset_hash,
-    asset_info_exists_for_hash,
-)
-from .api import schemas_in, schemas_out
 from ._assets_helpers import (
-    get_name_and_tags_from_asset_path,
     ensure_within_base,
+    get_name_and_tags_from_asset_path,
     resolve_destination_from_tags,
 )
+from .api import schemas_in, schemas_out
+from .database.db import create_session
+from .database.models import Asset
+from .database.services import (
+    add_tags_to_asset_info,
+    asset_exists_by_hash,
+    asset_info_exists_for_asset_id,
+    check_fs_asset_exists_quick,
+    create_asset_info_for_existing_asset,
+    delete_asset_info_by_id,
+    fetch_asset_info_and_asset,
+    fetch_asset_info_asset_and_tags,
+    get_asset_by_hash,
+    get_asset_info_by_id,
+    get_asset_tags,
+    ingest_fs_asset,
+    list_asset_infos_page,
+    list_cache_states_by_asset_id,
+    list_tags_with_usage,
+    remove_tags_from_asset_info,
+    set_asset_info_preview,
+    touch_asset_info_by_id,
+    touch_asset_infos_by_fs_path,
+    update_asset_info_full,
+)
+from .storage import hashing


 async def asset_exists(*, asset_hash: str) -> bool:
@@ -44,29 +45,21 @@ async def asset_exists(*, asset_hash: str) -> bool:


 def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
-    if not args.enable_model_processing:
-        if tags is None:
-            tags = []
-        try:
-            asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
-            async_to_sync.AsyncToSyncConverter.run_async_in_thread(
-                add_local_asset,
-                tags=list(dict.fromkeys([*path_tags, *tags])),
-                file_name=asset_name,
-                file_path=file_path,
-            )
-        except ValueError as e:
-            logging.warning("Skipping non-asset path %s: %s", file_path, e)
+    if tags is None:
+        tags = []
+    try:
+        asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
+        async_to_sync.AsyncToSyncConverter.run_async_in_thread(
+            add_local_asset,
+            tags=list(dict.fromkeys([*path_tags, *tags])),
+            file_name=asset_name,
+            file_path=file_path,
+        )
+    except ValueError as e:
+        logging.warning("Skipping non-asset path %s: %s", file_path, e)


 async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
-    """Adds a local asset to the DB. If already present and unchanged, does nothing.
-
-    Notes:
-    - Uses absolute path as the canonical locator for the cache backend.
-    - Computes BLAKE3 only when the fast existence check indicates it's needed.
-    - This function ensures the identity row and seeds mtime in asset_cache_state.
-    """
     abs_path = os.path.abspath(file_path)
     size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
     if not size_bytes:
@@ -132,7 +125,7 @@ async def list_assets(
         schemas_out.AssetSummary(
             id=info.id,
             name=info.name,
-            asset_hash=info.asset_hash,
+            asset_hash=asset.hash if asset else None,
             size=int(asset.size_bytes) if asset else None,
             mime_type=asset.mime_type if asset else None,
             tags=tags,
@@ -156,16 +149,17 @@ async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.As
         if not res:
             raise ValueError(f"AssetInfo {asset_info_id} not found")
         info, asset, tag_names = res
+        preview_id = info.preview_id

         return schemas_out.AssetDetail(
             id=info.id,
             name=info.name,
-            asset_hash=info.asset_hash,
+            asset_hash=asset.hash if asset else None,
             size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
             mime_type=asset.mime_type if asset else None,
             tags=tag_names,
-            preview_hash=info.preview_hash,
             user_metadata=info.user_metadata or {},
+            preview_id=preview_id,
             created_at=info.created_at,
             last_access_time=info.last_access_time,
         )
@@ -176,20 +170,13 @@ async def resolve_asset_content_for_download(
     asset_info_id: str,
     owner_id: str = "",
 ) -> tuple[str, str, str]:
-    """
-    Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
-    Also touches last_access_time (only_if_newer).
-    Raises:
-        ValueError if AssetInfo cannot be found
-        FileNotFoundError if file for Asset cannot be found
-    """
     async with await create_session() as session:
         pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
         if not pair:
             raise ValueError(f"AssetInfo {asset_info_id} not found")

         info, asset = pair
-        states = await list_cache_states_by_asset_hash(session, asset_hash=info.asset_hash)
+        states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
         abs_path = ""
         for s in states:
             if s and s.file_path and os.path.isfile(s.file_path):
@@ -214,16 +201,6 @@ async def upload_asset_from_temp_path(
     owner_id: str = "",
     expected_asset_hash: Optional[str] = None,
 ) -> schemas_out.AssetCreated:
-    """
-    Finalize an uploaded temp file:
-      - compute blake3 hash
-      - if expected_asset_hash provided, verify equality (400 on mismatch at caller)
-      - if an Asset with the same hash exists: discard temp, create AssetInfo only (no write)
-      - else resolve destination from tags and atomically move into place
-      - ingest into DB (assets, locator state, asset_info + tags)
-    Returns a populated AssetCreated payload.
-    """

     try:
         digest = await hashing.blake3_hash(temp_path)
     except Exception as e:
@@ -233,7 +210,6 @@ async def upload_asset_from_temp_path(
     if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
         raise ValueError("HASH_MISMATCH")

-    # Fast path: content already known --> no writes, just create a reference
     async with await create_session() as session:
         existing = await get_asset_by_hash(session, asset_hash=asset_hash)
         if existing is not None:
@@ -257,43 +233,37 @@ async def upload_asset_from_temp_path(
             return schemas_out.AssetCreated(
                 id=info.id,
                 name=info.name,
-                asset_hash=info.asset_hash,
+                asset_hash=existing.hash,
                 size=int(existing.size_bytes) if existing.size_bytes is not None else None,
                 mime_type=existing.mime_type,
                 tags=tag_names,
                 user_metadata=info.user_metadata or {},
-                preview_hash=info.preview_hash,
+                preview_id=info.preview_id,
                 created_at=info.created_at,
                 last_access_time=info.last_access_time,
                 created_new=False,
             )

-    # Resolve destination (only for truly new content)
     base_dir, subdirs = resolve_destination_from_tags(spec.tags)
     dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
     os.makedirs(dest_dir, exist_ok=True)

-    # Decide filename
     desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
     dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name))
     ensure_within_base(dest_abs, base_dir)

-    # Content type based on final name
     content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream"

-    # Atomic move into place
     try:
         os.replace(temp_path, dest_abs)
     except Exception as e:
         raise RuntimeError(f"failed to move uploaded file into place: {e}")

-    # Stat final file
     try:
         size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
     except OSError as e:
         raise RuntimeError(f"failed to stat destination file: {e}")

-    # Ingest + build response
     async with await create_session() as session:
         result = await ingest_fs_asset(
             session,
@@ -304,7 +274,7 @@ async def upload_asset_from_temp_path(
             mime_type=content_type,
             info_name=os.path.basename(dest_abs),
             owner_id=owner_id,
-            preview_hash=None,
+            preview_id=None,
             user_metadata=spec.user_metadata or {},
             tags=spec.tags,
             tag_origin="manual",
@@ -324,12 +294,12 @@ async def upload_asset_from_temp_path(
         return schemas_out.AssetCreated(
             id=info.id,
             name=info.name,
-            asset_hash=info.asset_hash,
+            asset_hash=asset.hash,
             size=int(asset.size_bytes),
             mime_type=asset.mime_type,
             tags=tag_names,
             user_metadata=info.user_metadata or {},
-            preview_hash=info.preview_hash,
+            preview_id=info.preview_id,
             created_at=info.created_at,
             last_access_time=info.last_access_time,
             created_new=result["asset_created"],
@@ -367,38 +337,74 @@ async def update_asset(
     return schemas_out.AssetUpdated(
         id=info.id,
         name=info.name,
-        asset_hash=info.asset_hash,
+        asset_hash=info.asset.hash if info.asset else None,
         tags=tag_names,
         user_metadata=info.user_metadata or {},
         updated_at=info.updated_at,
     )


-async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
-    """Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default),
-    delete the Asset row as well and remove all cached files recorded for that asset_hash.
-    """
+async def set_asset_preview(
+    *,
+    asset_info_id: str,
+    preview_asset_id: Optional[str],
+    owner_id: str = "",
+) -> schemas_out.AssetDetail:
     async with await create_session() as session:
         info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
-        asset_hash = info_row.asset_hash if info_row else None
+        if not info_row:
+            raise ValueError(f"AssetInfo {asset_info_id} not found")
+        if info_row.owner_id and info_row.owner_id != owner_id:
+            raise PermissionError("not owner")
+
+        await set_asset_info_preview(
+            session,
+            asset_info_id=asset_info_id,
+            preview_asset_id=preview_asset_id,
+        )
+
+        res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
+        if not res:
+            raise RuntimeError("State changed during preview update")
+        info, asset, tags = res
+        await session.commit()
+
+        return schemas_out.AssetDetail(
+            id=info.id,
+            name=info.name,
+            asset_hash=asset.hash if asset else None,
+            size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
+            mime_type=asset.mime_type if asset else None,
+            tags=tags,
+            user_metadata=info.user_metadata or {},
+            preview_id=info.preview_id,
+            created_at=info.created_at,
+            last_access_time=info.last_access_time,
+        )
+
+
+async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
+    async with await create_session() as session:
+        info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
+        asset_id = info_row.asset_id if info_row else None
         deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
         if not deleted:
             await session.commit()
             return False

-        if not delete_content_if_orphan or not asset_hash:
+        if not delete_content_if_orphan or not asset_id:
             await session.commit()
             return True

-        still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash)
+        still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
         if still_exists:
             await session.commit()
             return True

-        states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash)
+        states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
         file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]

-        asset_row = await get_asset_by_hash(session, asset_hash=asset_hash)
+        asset_row = await session.get(Asset, asset_id)
         if asset_row is not None:
             await session.delete(asset_row)

@@ -439,12 +445,12 @@ async def create_asset_from_hash(
     return schemas_out.AssetCreated(
         id=info.id,
         name=info.name,
-        asset_hash=info.asset_hash,
+        asset_hash=asset.hash,
         size=int(asset.size_bytes),
         mime_type=asset.mime_type,
         tags=tag_names,
         user_metadata=info.user_metadata or {},
-        preview_hash=info.preview_hash,
+        preview_id=info.preview_id,
         created_at=info.created_at,
         last_access_time=info.last_access_time,
         created_new=False,
@@ -1,52 +1,55 @@
 import asyncio
-import contextlib
 import logging
 import os
-import uuid
 import time
 from dataclasses import dataclass, field
-from datetime import datetime, timezone
-from typing import Callable, Literal, Optional, Sequence
+from typing import Literal, Optional
+
+import sqlalchemy as sa

 import folder_paths

-from . import assets_manager
-from .api import schemas_out
-from ._assets_helpers import get_comfy_models_folders
+from ._assets_helpers import (
+    collect_models_files,
+    get_comfy_models_folders,
+    get_name_and_tags_from_asset_path,
+    list_tree,
+    new_scan_id,
+    prefixes_for_root,
+    ts_to_iso,
+)
+from .api import schemas_in, schemas_out
 from .database.db import create_session
+from .database.helpers import (
+    add_missing_tag_for_asset_id,
+    remove_missing_tag_for_asset_id,
+)
+from .database.models import Asset, AssetCacheState, AssetInfo
 from .database.services import (
-    check_fs_asset_exists_quick,
+    compute_hash_and_dedup_for_cache_state,
+    ensure_seed_for_path,
+    list_cache_states_by_asset_id,
     list_cache_states_with_asset_under_prefixes,
-    add_missing_tag_for_asset_hash,
-    remove_missing_tag_for_asset_hash,
+    list_unhashed_candidates_under_prefixes,
+    list_verify_candidates_under_prefixes,
 )

 LOGGER = logging.getLogger(__name__)

-RootType = Literal["models", "input", "output"]
-ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")

 SLOW_HASH_CONCURRENCY = 1


 @dataclass
 class ScanProgress:
     scan_id: str
-    root: RootType
+    root: schemas_in.RootType
     status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
     scheduled_at: float = field(default_factory=lambda: time.time())
     started_at: Optional[float] = None
     finished_at: Optional[float] = None

     discovered: int = 0
     processed: int = 0
-    slow_queue_total: int = 0
-    slow_queue_finished: int = 0
-    file_errors: list[dict] = field(default_factory=list)  # {"path","message","phase","at"}
-
-    # Internal diagnostics for logs
-    _fast_total_seen: int = 0
-    _fast_clean: int = 0
+    file_errors: list[dict] = field(default_factory=list)


 @dataclass
@@ -56,18 +59,14 @@ class SlowQueueState:
     closed: bool = False


-RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
-PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
-SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {}
+RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
+PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
+SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}


-async def start_background_assets_scan():
-    await fast_reconcile_and_kickoff(progress_cb=_console_cb)


 def current_statuses() -> schemas_out.AssetScanStatusResponse:
     scans = []
-    for root in ALLOWED_ROOTS:
+    for root in schemas_in.ALLOWED_ROOTS:
         prog = PROGRESS_BY_ROOT.get(root)
         if not prog:
             continue
@@ -75,83 +74,65 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
     return schemas_out.AssetScanStatusResponse(scans=scans)


-async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
-    """Schedule scans for the provided roots; returns progress snapshots.
-
-    Rules:
-    - Only roots in {models, input, output} are accepted.
-    - If a root is already scanning, we do NOT enqueue another one. Status returned as-is.
-    - Otherwise a new task is created and started immediately.
-    - Files with zero size are skipped.
-    """
-    normalized: list[RootType] = []
-    seen = set()
-    for r in roots or []:
-        rr = r.strip().lower()
-        if rr in ALLOWED_ROOTS and rr not in seen:
-            normalized.append(rr)  # type: ignore
-            seen.add(rr)
-    if not normalized:
-        normalized = list(ALLOWED_ROOTS)  # schedule all by default
-
+async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
     results: list[ScanProgress] = []
-    for root in normalized:
+    for root in roots:
         if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
             results.append(PROGRESS_BY_ROOT[root])
             continue

-        prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
+        prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
         PROGRESS_BY_ROOT[root] = prog
-        SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue())
+        state = SlowQueueState(queue=asyncio.Queue())
+        SLOW_STATE_BY_ROOT[root] = state
         RUNNING_TASKS[root] = asyncio.create_task(
-            _pipeline_for_root(root, prog, progress_cb=None),
+            _run_hash_verify_pipeline(root, prog, state),
             name=f"asset-scan:{root}",
         )
         results.append(prog)
     return _status_response_for(results)


-async def fast_reconcile_and_kickoff(
-    roots: Optional[Sequence[str]] = None,
-    *,
-    progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None,
-) -> schemas_out.AssetScanStatusResponse:
-    """
-    Startup helper: do the fast pass now (so we know queue size),
-    start slow hashing in the background, return immediately.
-    """
-    normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS]
-    snaps: list[ScanProgress] = []
-
-    for root in normalized:
-        if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
-            snaps.append(PROGRESS_BY_ROOT[root])
-            continue
-
-        prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
-        PROGRESS_BY_ROOT[root] = prog
-        state = SlowQueueState(queue=asyncio.Queue())
-        SLOW_STATE_BY_ROOT[root] = state
-
-        prog.status = "running"
-        prog.started_at = time.time()
+async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
+    for r in roots:
         try:
-            await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
-        except Exception as e:
-            _append_error(prog, phase="fast", path="", message=str(e))
-            prog.status = "failed"
-            prog.finished_at = time.time()
-            LOGGER.exception("Fast reconcile failed for %s", root)
-            snaps.append(prog)
+            await _fast_db_consistency_pass(r)
+        except Exception as ex:
+            LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
+
+    paths: list[str] = []
+    if "models" in roots:
+        paths.extend(collect_models_files())
+    if "input" in roots:
+        paths.extend(list_tree(folder_paths.get_input_directory()))
+    if "output" in roots:
+        paths.extend(list_tree(folder_paths.get_output_directory()))
+
+    for p in paths:
+        try:
+            st = os.stat(p, follow_symlinks=True)
+            if not int(st.st_size or 0):
+                continue
+            size_bytes = int(st.st_size)
+            mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
+            name, tags = get_name_and_tags_from_asset_path(p)
+            await _seed_one_async(p, size_bytes, mtime_ns, name, tags)
+        except OSError:
             continue

-        _start_slow_workers(root, prog, state, progress_cb=progress_cb)
-        RUNNING_TASKS[root] = asyncio.create_task(
-            _await_workers_then_finish(root, prog, state, progress_cb=progress_cb),
-            name=f"asset-hash:{root}",
+
+async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None:
+    async with await create_session() as sess:
+        await ensure_seed_for_path(
+            sess,
+            abs_path=p,
+            size_bytes=size_bytes,
+            mtime_ns=mtime_ns,
+            info_name=name,
+            tags=tags,
+            owner_id="",
         )
-        snaps.append(prog)
-    return _status_response_for(snaps)
+        await sess.commit()


 def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
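sync_seed_assets replaces the old fast/slow kickoff at startup: it only stats files and seeds rows, leaving hashing and verification to the scheduled pipeline. A possible wiring at startup, inferred from the new __init__ exports (the call site itself is not part of this diff and is an assumption):

    async def assets_startup() -> None:
        await sync_seed_assets(["models", "input", "output"])   # cheap pass: rows without hashes
        await schedule_scans(["models", "input", "output"])     # background hashing / verification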
@@ -163,18 +144,15 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A
         scan_id=progress.scan_id,
         root=progress.root,
         status=progress.status,
-        scheduled_at=_ts_to_iso(progress.scheduled_at),
-        started_at=_ts_to_iso(progress.started_at),
-        finished_at=_ts_to_iso(progress.finished_at),
+        scheduled_at=ts_to_iso(progress.scheduled_at),
+        started_at=ts_to_iso(progress.started_at),
+        finished_at=ts_to_iso(progress.finished_at),
         discovered=progress.discovered,
         processed=progress.processed,
-        slow_queue_total=progress.slow_queue_total,
-        slow_queue_finished=progress.slow_queue_finished,
         file_errors=[
             schemas_out.AssetScanError(
                 path=e.get("path", ""),
                 message=e.get("message", ""),
-                phase=e.get("phase", "slow"),
                 at=e.get("at"),
             )
             for e in (progress.file_errors or [])
@@ -182,27 +160,100 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A
     )


-async def _pipeline_for_root(
-    root: RootType,
-    prog: ScanProgress,
-    progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
-) -> None:
-    state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue())
-    SLOW_STATE_BY_ROOT[root] = state
+async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
+    """Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime."""
+    prefixes = prefixes_for_root(root)
+    if not prefixes:
+        return
+
+    conds = []
+    for p in prefixes:
+        base = os.path.abspath(p)
+        if not base.endswith(os.sep):
+            base += os.sep
+        conds.append(AssetCacheState.file_path.like(base + "%"))
+
+    async with await create_session() as sess:
+        rows = (
+            await sess.execute(
+                sa.select(
+                    AssetCacheState.id,
+                    AssetCacheState.mtime_ns,
+                    AssetCacheState.needs_verify,
+                    Asset.hash,
+                    AssetCacheState.file_path,
+                )
+                .join(Asset, Asset.id == AssetCacheState.asset_id)
+                .where(sa.or_(*conds))
+            )
+        ).all()
+
+        to_set = []
+        to_clear = []
+        for sid, mtime_db, needs_verify, a_hash, fp in rows:
+            try:
+                st = os.stat(fp, follow_symlinks=True)
+            except OSError:
+                # Missing files are handled by missing-tag reconciliation later.
+                continue
+
+            actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
+            if a_hash is not None:
+                if mtime_db is None or int(mtime_db) != int(actual_mtime_ns):
+                    if not needs_verify:
+                        to_set.append(sid)
+            else:
+                if needs_verify:
+                    to_clear.append(sid)
+
+        if to_set:
+            await sess.execute(
+                sa.update(AssetCacheState)
+                .where(AssetCacheState.id.in_(to_set))
+                .values(needs_verify=True)
+            )
+        if to_clear:
+            await sess.execute(
+                sa.update(AssetCacheState)
+                .where(AssetCacheState.id.in_(to_clear))
+                .values(needs_verify=False)
+            )
+        await sess.commit()


+async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
     prog.status = "running"
     prog.started_at = time.time()

     try:
-        await _reconcile_missing_tags_for_root(root, prog)
-        await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
-        _start_slow_workers(root, prog, state, progress_cb=progress_cb)
-        await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb)
+        prefixes = prefixes_for_root(root)
+
+        await _refresh_verify_flags_for_root(root, prog)
+
+        # collect candidates from DB
+        async with await create_session() as sess:
+            verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
+            unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
+        # dedupe: prioritize verification first
+        seen = set()
+        ordered: list[int] = []
+        for lst in (verify_ids, unhashed_ids):
+            for sid in lst:
+                if sid not in seen:
+                    seen.add(sid); ordered.append(sid)
+
+        prog.discovered = len(ordered)
+
+        # queue up work
+        for sid in ordered:
+            await state.queue.put(sid)
+        state.closed = True
+        _start_state_workers(root, prog, state)
+        await _await_state_workers_then_finish(root, prog, state)
     except asyncio.CancelledError:
         prog.status = "cancelled"
         raise
     except Exception as exc:
-        _append_error(prog, phase="slow", path="", message=str(exc))
+        _append_error(prog, path="", message=str(exc))
         prog.status = "failed"
         prog.finished_at = time.time()
         LOGGER.exception("Asset scan failed for %s", root)
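The verify pass boils down to a per-row decision on needs_verify. A simplified, self-contained restatement of that decision (the function name and return convention are illustrative, not part of the diff):

import os
from typing import Optional

def wants_verify(file_path: str, mtime_db: Optional[int], hashed: bool, currently_flagged: bool) -> Optional[bool]:
    """Return True to set needs_verify, False to clear it, None to leave it unchanged."""
    if not hashed:
        # unhashed seed assets never need content verification; clear a stale flag
        return False if currently_flagged else None
    try:
        st = os.stat(file_path, follow_symlinks=True)
    except OSError:
        return None  # missing files are handled by missing-tag reconciliation instead
    actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
    if mtime_db is None or int(mtime_db) != int(actual_mtime_ns):
        return None if currently_flagged else True
    return None

print(wants_verify(__file__, mtime_db=None, hashed=True, currently_flagged=False))  # True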
@@ -210,110 +261,13 @@ async def _pipeline_for_root(
|
|||||||
RUNNING_TASKS.pop(root, None)
|
RUNNING_TASKS.pop(root, None)
|
||||||
|
|
||||||
|
|
||||||
async def _fast_reconcile_into_queue(
|
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
|
||||||
root: RootType,
|
|
||||||
prog: ScanProgress,
|
|
||||||
state: SlowQueueState,
|
|
||||||
*,
|
|
||||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
|
||||||
) -> None:
|
|
||||||
"""
|
"""
|
||||||
Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files,
|
Detect missing files quickly and toggle 'missing' tag per asset_id.
|
||||||
and queue the rest for slow hashing.
|
|
||||||
"""
|
|
||||||
if root == "models":
|
|
||||||
files = _collect_models_files()
|
|
||||||
preset_discovered = _count_nonzero_in_list(files)
|
|
||||||
files_iter = asyncio.Queue()
|
|
||||||
for p in files:
|
|
||||||
await files_iter.put(p)
|
|
||||||
await files_iter.put(None) # sentinel for our local draining loop
|
|
||||||
elif root == "input":
|
|
||||||
base = folder_paths.get_input_directory()
|
|
||||||
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
|
|
||||||
files_iter = await _queue_tree_files(base)
|
|
||||||
elif root == "output":
|
|
||||||
base = folder_paths.get_output_directory()
|
|
||||||
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
|
|
||||||
files_iter = await _queue_tree_files(base)
|
|
||||||
else:
|
|
||||||
raise RuntimeError(f"Unsupported root: {root}")
|
|
||||||
|
|
||||||
prog.discovered = int(preset_discovered or 0)
|
Rules:
|
||||||
|
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
|
||||||
queued = 0
|
- We consider ALL cache states of the asset (across roots) before tagging.
|
||||||
checked = 0
|
|
||||||
clean = 0
|
|
||||||
|
|
||||||
async with await create_session() as sess:
|
|
||||||
while True:
|
|
||||||
item = await files_iter.get()
|
|
||||||
files_iter.task_done()
|
|
||||||
if item is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
abs_path = item
|
|
||||||
checked += 1
|
|
||||||
|
|
||||||
# Stat; skip empty/unreadable
|
|
||||||
try:
|
|
||||||
st = os.stat(abs_path, follow_symlinks=True)
|
|
||||||
if not st.st_size:
|
|
||||||
continue
|
|
||||||
size_bytes = int(st.st_size)
|
|
||||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
|
||||||
except OSError as e:
|
|
||||||
_append_error(prog, phase="fast", path=abs_path, message=str(e))
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
known = await check_fs_asset_exists_quick(
|
|
||||||
sess,
|
|
||||||
file_path=abs_path,
|
|
||||||
size_bytes=size_bytes,
|
|
||||||
mtime_ns=mtime_ns,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
_append_error(prog, phase="fast", path=abs_path, message=str(e))
|
|
||||||
known = False
|
|
||||||
|
|
||||||
if known:
|
|
||||||
clean += 1
|
|
||||||
prog.processed += 1
|
|
||||||
else:
|
|
||||||
await state.queue.put(abs_path)
|
|
||||||
queued += 1
|
|
||||||
prog.slow_queue_total += 1
|
|
||||||
|
|
||||||
if progress_cb:
|
|
||||||
progress_cb(root, "fast", prog.processed, False, {
|
|
||||||
"checked": checked,
|
|
||||||
"clean": clean,
|
|
||||||
"queued": queued,
|
|
||||||
"discovered": prog.discovered,
|
|
||||||
})
|
|
||||||
|
|
||||||
prog._fast_total_seen = checked
|
|
||||||
prog._fast_clean = clean
|
|
||||||
|
|
||||||
if progress_cb:
|
|
||||||
progress_cb(root, "fast", prog.processed, True, {
|
|
||||||
"checked": checked,
|
|
||||||
"clean": clean,
|
|
||||||
"queued": queued,
|
|
||||||
"discovered": prog.discovered,
|
|
||||||
})
|
|
||||||
state.closed = True
|
|
||||||
|
|
||||||
|
|
||||||
async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None:
|
|
||||||
"""
|
|
||||||
Logic for detecting missing Assets files:
|
|
||||||
- Clear 'missing' only if at least one cached path passes fast check:
|
|
||||||
exists AND mtime_ns matches AND size matches.
|
|
||||||
- Otherwise set 'missing'.
|
|
||||||
Files that exist but fail fast check will be slow-hashed by the normal pipeline,
|
|
||||||
and ingest_fs_asset will clear 'missing' if they truly match.
|
|
||||||
"""
|
"""
|
||||||
if root == "models":
|
if root == "models":
|
||||||
bases: list[str] = []
|
bases: list[str] = []
|
||||||
@@ -326,232 +280,217 @@ async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with await create_session() as sess:
|
async with await create_session() as sess:
|
||||||
|
# state + hash + size for the current root
|
||||||
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
|
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
|
||||||
|
|
||||||
by_hash: dict[str, dict[str, bool]] = {} # {hash: {"any_fast_ok": bool}}
|
# Track fast_ok within the scanned root and whether the asset is hashed
|
||||||
for state, size_db in rows:
|
by_asset: dict[str, dict[str, bool]] = {}
|
||||||
h = state.asset_hash
|
for state, a_hash, size_db in rows:
|
||||||
acc = by_hash.get(h)
|
aid = state.asset_id
|
||||||
|
acc = by_asset.get(aid)
|
||||||
if acc is None:
|
if acc is None:
|
||||||
acc = {"any_fast_ok": False}
|
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
|
||||||
by_hash[h] = acc
|
by_asset[aid] = acc
|
||||||
try:
|
try:
|
||||||
st = os.stat(state.file_path, follow_symlinks=True)
|
st = os.stat(state.file_path, follow_symlinks=True)
|
||||||
actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||||
fast_ok = False
|
fast_ok = False
|
||||||
if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
|
if acc["hashed"]:
|
||||||
if int(size_db) > 0 and int(st.st_size) == int(size_db):
|
if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
|
||||||
fast_ok = True
|
if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]):
|
||||||
|
fast_ok = True
|
||||||
if fast_ok:
|
if fast_ok:
|
||||||
acc["any_fast_ok"] = True
|
acc["any_fast_ok_here"] = True
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass # not fast_ok
|
pass
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
_append_error(prog, phase="fast", path=state.file_path, message=str(e))
|
_append_error(prog, path=state.file_path, message=str(e))
|
||||||
|
|
||||||
for h, acc in by_hash.items():
|
# Decide per asset, considering ALL its states (not just this root)
|
||||||
|
for aid, acc in by_asset.items():
|
||||||
try:
|
try:
|
||||||
if acc["any_fast_ok"]:
|
if not acc["hashed"]:
|
||||||
await remove_missing_tag_for_asset_hash(sess, asset_hash=h)
|
# Never tag seed assets as missing
|
||||||
|
continue
|
||||||
|
|
||||||
|
any_fast_ok_global = acc["any_fast_ok_here"]
|
||||||
|
if not any_fast_ok_global:
|
||||||
|
# Check other states outside this root
|
||||||
|
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
|
||||||
|
for st in others:
|
||||||
|
try:
|
||||||
|
s = os.stat(st.file_path, follow_symlinks=True)
|
||||||
|
actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
|
||||||
|
if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
|
||||||
|
if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]:
|
||||||
|
any_fast_ok_global = True
|
||||||
|
break
|
||||||
|
except OSError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if any_fast_ok_global:
|
||||||
|
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||||
else:
|
else:
|
||||||
await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic")
|
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
_append_error(prog, phase="fast", path="", message=f"reconcile {h[:18]}: {ex}")
|
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
|
||||||
|
|
||||||
await sess.commit()
|
await sess.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_append_error(prog, phase="fast", path="", message=f"reconcile failed: {e}")
|
_append_error(prog, path="", message=f"reconcile failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
def _start_slow_workers(
|
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||||
root: RootType,
|
|
||||||
prog: ScanProgress,
|
|
||||||
state: SlowQueueState,
|
|
||||||
*,
|
|
||||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
|
||||||
) -> None:
|
|
||||||
if state.workers:
|
if state.workers:
|
||||||
return
|
return
|
||||||
|
|
||||||
async def _worker(_worker_id: int):
|
async def _worker(_wid: int):
|
||||||
while True:
|
while True:
|
||||||
item = await state.queue.get()
|
sid = await state.queue.get()
|
||||||
try:
|
try:
|
||||||
if item is None:
|
if sid is None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(assets_manager.populate_db_with_asset, item)
|
async with await create_session() as sess:
|
||||||
except Exception as e:
|
# Optional: fetch path for better error messages
|
||||||
_append_error(prog, phase="slow", path=item, message=str(e))
|
st = await sess.get(AssetCacheState, sid)
|
||||||
|
try:
|
||||||
|
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
|
||||||
|
await sess.commit()
|
||||||
|
except Exception as e:
|
||||||
|
path = st.file_path if st else f"state:{sid}"
|
||||||
|
_append_error(prog, path=path, message=str(e))
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
finally:
|
finally:
|
||||||
# Slow queue finished for this item; also counts toward overall processed
|
|
||||||
prog.slow_queue_finished += 1
|
|
||||||
prog.processed += 1
|
prog.processed += 1
|
||||||
if progress_cb:
|
|
||||||
progress_cb(root, "slow", prog.processed, False, {
|
|
||||||
"slow_queue_finished": prog.slow_queue_finished,
|
|
||||||
"slow_queue_total": prog.slow_queue_total,
|
|
||||||
})
|
|
||||||
finally:
|
finally:
|
||||||
state.queue.task_done()
|
state.queue.task_done()
|
||||||
|
|
||||||
state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)]
|
state.workers = [
|
||||||
|
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
|
||||||
|
for i in range(SLOW_HASH_CONCURRENCY)
|
||||||
|
]
|
||||||
|
|
||||||
async def _close_when_empty():
|
async def _close_when_ready():
|
||||||
# When the fast phase closed the queue, push sentinels to end workers
|
|
||||||
while not state.closed:
|
while not state.closed:
|
||||||
await asyncio.sleep(0.05)
|
await asyncio.sleep(0.05)
|
||||||
for _ in range(SLOW_HASH_CONCURRENCY):
|
for _ in range(SLOW_HASH_CONCURRENCY):
|
||||||
await state.queue.put(None)
|
await state.queue.put(None)
|
||||||
|
|
||||||
asyncio.create_task(_close_when_empty())
|
asyncio.create_task(_close_when_ready())
|
||||||
|
|
||||||
|
|
||||||
async def _await_workers_then_finish(
|
async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||||
root: RootType,
|
|
||||||
prog: ScanProgress,
|
|
||||||
state: SlowQueueState,
|
|
||||||
*,
|
|
||||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
|
||||||
) -> None:
|
|
||||||
if state.workers:
|
if state.workers:
|
||||||
await asyncio.gather(*state.workers, return_exceptions=True)
|
await asyncio.gather(*state.workers, return_exceptions=True)
|
||||||
await _reconcile_missing_tags_for_root(root, prog)
|
await _reconcile_missing_tags_for_root(root, prog)
|
||||||
prog.finished_at = time.time()
|
prog.finished_at = time.time()
|
||||||
prog.status = "completed"
|
prog.status = "completed"
|
||||||
if progress_cb:
|
|
||||||
progress_cb(root, "slow", prog.processed, True, {
|
|
||||||
"slow_queue_finished": prog.slow_queue_finished,
|
|
||||||
"slow_queue_total": prog.slow_queue_total,
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_models_files() -> list[str]:
|
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
|
||||||
"""Collect absolute file paths from configured model buckets under models_dir."""
|
|
||||||
out: list[str] = []
|
|
||||||
for folder_name, bases in get_comfy_models_folders():
|
|
||||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
|
||||||
for rel_path in rel_files:
|
|
||||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
|
||||||
if not abs_path:
|
|
||||||
continue
|
|
||||||
abs_path = os.path.abspath(abs_path)
|
|
||||||
# ensure within allowed bases
|
|
||||||
allowed = False
|
|
||||||
for b in bases:
|
|
||||||
base_abs = os.path.abspath(b)
|
|
||||||
with contextlib.suppress(Exception):
|
|
||||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
|
||||||
allowed = True
|
|
||||||
break
|
|
||||||
if allowed:
|
|
||||||
out.append(abs_path)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int:
|
|
||||||
if not os.path.isdir(base_abs):
|
|
||||||
return 0
|
|
||||||
total = 0
|
|
||||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
|
||||||
if not only_nonzero:
|
|
||||||
total += len(filenames)
|
|
||||||
else:
|
|
||||||
for name in filenames:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
st = os.stat(os.path.join(dirpath, name), follow_symlinks=True)
|
|
||||||
if st.st_size:
|
|
||||||
total += 1
|
|
||||||
return total
|
|
||||||
|
|
||||||
|
|
||||||
def _count_nonzero_in_list(paths: list[str]) -> int:
|
|
||||||
cnt = 0
|
|
||||||
for p in paths:
|
|
||||||
with contextlib.suppress(OSError):
|
|
||||||
st = os.stat(p, follow_symlinks=True)
|
|
||||||
if st.st_size:
|
|
||||||
cnt += 1
|
|
||||||
return cnt
|
|
||||||
|
|
||||||
|
|
||||||
async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
|
|
||||||
"""
|
|
||||||
Walk base_dir in a worker thread and return a queue prefilled with all paths,
|
|
||||||
terminated by a single None sentinel for the draining loop in fast reconcile.
|
|
||||||
"""
|
|
||||||
q: asyncio.Queue = asyncio.Queue()
|
|
||||||
base_abs = os.path.abspath(base_dir)
|
|
||||||
if not os.path.isdir(base_abs):
|
|
||||||
await q.put(None)
|
|
||||||
return q
|
|
||||||
|
|
||||||
def _walk_list():
|
|
||||||
paths: list[str] = []
|
|
||||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
|
||||||
for name in filenames:
|
|
||||||
paths.append(os.path.abspath(os.path.join(dirpath, name)))
|
|
||||||
return paths
|
|
||||||
|
|
||||||
for p in await asyncio.to_thread(_walk_list):
|
|
||||||
await q.put(p)
|
|
||||||
await q.put(None)
|
|
||||||
return q
|
|
||||||
|
|
||||||
|
|
||||||
def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None:
|
|
||||||
prog.file_errors.append({
|
prog.file_errors.append({
|
||||||
"path": path,
|
"path": path,
|
||||||
"message": message,
|
"message": message,
|
||||||
"phase": phase,
|
"at": ts_to_iso(time.time()),
|
||||||
"at": _ts_to_iso(time.time()),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
|
async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None:
|
||||||
if ts is None:
|
"""
|
||||||
return None
|
Quick pass over asset_cache_state for `root`:
|
||||||
# interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models)
|
- If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos.
|
||||||
try:
|
- If file missing and Asset.hash is NOT NULL:
|
||||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
|
* If at least one state for this Asset is fast-ok, delete the missing state.
|
||||||
except Exception:
|
* If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset.
|
||||||
return None
|
- If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag.
|
||||||
|
"""
|
||||||
|
prefixes = prefixes_for_root(root)
|
||||||
|
if not prefixes:
|
||||||
|
return
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
for p in prefixes:
|
||||||
|
base = os.path.abspath(p)
|
||||||
|
if not base.endswith(os.sep):
|
||||||
|
base += os.sep
|
||||||
|
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||||
|
|
||||||
def _new_scan_id(root: RootType) -> str:
|
async with await create_session() as sess:
|
||||||
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
|
if not conds:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows = (
|
||||||
def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: dict):
|
await sess.execute(
|
||||||
if phase == "fast":
|
sa.select(AssetCacheState, Asset.hash, Asset.size_bytes)
|
||||||
if finished:
|
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||||
logging.info(
|
.where(sa.or_(*conds))
|
||||||
"[assets][%s] fast done: processed=%s/%s queued=%s",
|
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||||
root,
|
|
||||||
total_processed,
|
|
||||||
e["discovered"],
|
|
||||||
e["queued"],
|
|
||||||
)
|
|
||||||
elif e.get("checked", 0) % 1000 == 0: # do not spam with fast progress
|
|
||||||
logging.info(
|
|
||||||
"[assets][%s] fast progress: processed=%s/%s",
|
|
||||||
root,
|
|
||||||
total_processed,
|
|
||||||
e["discovered"],
|
|
||||||
)
|
|
||||||
elif phase == "slow":
|
|
||||||
if finished:
|
|
||||||
if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0):
|
|
||||||
logging.info(
|
|
||||||
"[assets][%s] slow done: %s/%s",
|
|
||||||
root,
|
|
||||||
e.get("slow_queue_finished", 0),
|
|
||||||
e.get("slow_queue_total", 0),
|
|
||||||
)
|
|
||||||
elif e.get('slow_queue_finished', 0) % 3 == 0:
|
|
||||||
logging.info(
|
|
||||||
"[assets][%s] slow progress: %s/%s",
|
|
||||||
root,
|
|
||||||
e.get("slow_queue_finished", 0),
|
|
||||||
e.get("slow_queue_total", 0),
|
|
||||||
)
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# Group by asset_id with status per state
|
||||||
|
by_asset: dict[str, dict] = {}
|
||||||
|
for st, a_hash, a_size in rows:
|
||||||
|
aid = st.asset_id
|
||||||
|
acc = by_asset.get(aid)
|
||||||
|
if acc is None:
|
||||||
|
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||||
|
by_asset[aid] = acc
|
||||||
|
|
||||||
|
exists = False
|
||||||
|
fast_ok = False
|
||||||
|
try:
|
||||||
|
s = os.stat(st.file_path, follow_symlinks=True)
|
||||||
|
exists = True
|
||||||
|
actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
|
||||||
|
if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
|
||||||
|
if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]:
|
||||||
|
fast_ok = True
|
||||||
|
except FileNotFoundError:
|
||||||
|
exists = False
|
||||||
|
except OSError as ex:
|
||||||
|
exists = False
|
||||||
|
LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex)
|
||||||
|
|
||||||
|
acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok})
|
||||||
|
|
||||||
|
# Apply actions
|
||||||
|
for aid, acc in by_asset.items():
|
||||||
|
a_hash = acc["hash"]
|
||||||
|
states = acc["states"]
|
||||||
|
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||||
|
all_missing = all(not s["exists"] for s in states)
|
||||||
|
missing_states = [s["obj"] for s in states if not s["exists"]]
|
||||||
|
|
||||||
|
if a_hash is None:
|
||||||
|
# Seed asset: if all states gone (and in practice there is only one), remove the whole Asset
|
||||||
|
if states and all_missing:
|
||||||
|
await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||||
|
asset = await sess.get(Asset, aid)
|
||||||
|
if asset:
|
||||||
|
await sess.delete(asset)
|
||||||
|
# else leave it for the slow scan to verify/rehash
|
||||||
|
else:
|
||||||
|
if any_fast_ok:
|
||||||
|
# Remove 'missing' and delete just the stale state rows
|
||||||
|
for st in missing_states:
|
||||||
|
try:
|
||||||
|
await sess.delete(await sess.get(AssetCacheState, st.id))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# No fast-ok path: mark as missing
|
||||||
|
try:
|
||||||
|
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
await sess.flush()
|
||||||
|
await sess.commit()
|
||||||
|
|||||||
@@ -1,186 +0,0 @@
|
|||||||
from decimal import Decimal
|
|
||||||
from typing import Any, Sequence, Optional, Iterable
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy import select, exists
|
|
||||||
|
|
||||||
from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta
|
|
||||||
from .._assets_helpers import normalize_tags
|
|
||||||
|
|
||||||
|
|
||||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
|
||||||
wanted = normalize_tags(list(names))
|
|
||||||
if not wanted:
|
|
||||||
return []
|
|
||||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
|
||||||
by_name = {t.name: t for t in existing}
|
|
||||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
|
||||||
if to_create:
|
|
||||||
session.add_all(to_create)
|
|
||||||
await session.flush()
|
|
||||||
by_name.update({t.name: t for t in to_create})
|
|
||||||
return [by_name[n] for n in wanted]
|
|
||||||
|
|
||||||
|
|
||||||
def apply_tag_filters(
|
|
||||||
stmt: sa.sql.Select,
|
|
||||||
include_tags: Optional[Sequence[str]],
|
|
||||||
exclude_tags: Optional[Sequence[str]],
|
|
||||||
) -> sa.sql.Select:
|
|
||||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
|
||||||
include_tags = normalize_tags(include_tags)
|
|
||||||
exclude_tags = normalize_tags(exclude_tags)
|
|
||||||
|
|
||||||
if include_tags:
|
|
||||||
for tag_name in include_tags:
|
|
||||||
stmt = stmt.where(
|
|
||||||
exists().where(
|
|
||||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
|
||||||
& (AssetInfoTag.tag_name == tag_name)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if exclude_tags:
|
|
||||||
stmt = stmt.where(
|
|
||||||
~exists().where(
|
|
||||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
|
||||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def apply_metadata_filter(
|
|
||||||
stmt: sa.sql.Select,
|
|
||||||
metadata_filter: Optional[dict],
|
|
||||||
) -> sa.sql.Select:
|
|
||||||
"""Apply metadata filters using the projection table asset_info_meta.
|
|
||||||
|
|
||||||
Semantics:
|
|
||||||
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
|
|
||||||
- For None: key is missing OR key has explicit null (val_json IS NULL).
|
|
||||||
- For list values: ANY-of the list elements matches (EXISTS for any).
|
|
||||||
(Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))')
|
|
||||||
"""
|
|
||||||
if not metadata_filter:
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
|
||||||
return sa.exists().where(
|
|
||||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
|
||||||
AssetInfoMeta.key == key,
|
|
||||||
*preds,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
|
||||||
# Missing OR null:
|
|
||||||
if value is None:
|
|
||||||
# either: no row for key OR a row for key with explicit null
|
|
||||||
no_row_for_key = sa.not_(
|
|
||||||
sa.exists().where(
|
|
||||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
|
||||||
AssetInfoMeta.key == key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
null_row = _exists_for_pred(
|
|
||||||
key,
|
|
||||||
AssetInfoMeta.val_json.is_(None),
|
|
||||||
AssetInfoMeta.val_str.is_(None),
|
|
||||||
AssetInfoMeta.val_num.is_(None),
|
|
||||||
AssetInfoMeta.val_bool.is_(None),
|
|
||||||
)
|
|
||||||
return sa.or_(no_row_for_key, null_row)
|
|
||||||
|
|
||||||
# Typed scalar matches:
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
|
||||||
if isinstance(value, (int, float, Decimal)):
|
|
||||||
# store as Decimal for equality against NUMERIC(38,10)
|
|
||||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
|
||||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
|
||||||
if isinstance(value, str):
|
|
||||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
|
||||||
|
|
||||||
# Complex: compare JSON (no index, but supported)
|
|
||||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
|
||||||
|
|
||||||
for k, v in metadata_filter.items():
|
|
||||||
if isinstance(v, list):
|
|
||||||
# ANY-of (exists for any element)
|
|
||||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
|
||||||
if ors:
|
|
||||||
stmt = stmt.where(sa.or_(*ors))
|
|
||||||
else:
|
|
||||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def is_scalar(v: Any) -> bool:
|
|
||||||
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
|
||||||
return True
|
|
||||||
if isinstance(v, bool):
|
|
||||||
return True
|
|
||||||
if isinstance(v, (int, float, Decimal, str)):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def project_kv(key: str, value: Any) -> list[dict]:
|
|
||||||
"""
|
|
||||||
Turn a metadata key/value into one or more projection rows:
|
|
||||||
- scalar -> one row (ordinal=0) in the proper typed column
|
|
||||||
- list of scalars -> one row per element with ordinal=i
|
|
||||||
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
|
||||||
- None -> single row with all value columns NULL
|
|
||||||
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
|
||||||
"""
|
|
||||||
rows: list[dict] = []
|
|
||||||
|
|
||||||
def _null_row(ordinal: int) -> dict:
|
|
||||||
return {
|
|
||||||
"key": key, "ordinal": ordinal,
|
|
||||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
|
||||||
}
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
rows.append(_null_row(0))
|
|
||||||
return rows
|
|
||||||
|
|
||||||
if is_scalar(value):
|
|
||||||
if isinstance(value, bool):
|
|
||||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
|
||||||
elif isinstance(value, (int, float, Decimal)):
|
|
||||||
# store numeric; SQLAlchemy will coerce to Numeric
|
|
||||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
|
||||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
|
||||||
elif isinstance(value, str):
|
|
||||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
|
||||||
else:
|
|
||||||
# Fallback to json
|
|
||||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
|
||||||
return rows
|
|
||||||
|
|
||||||
if isinstance(value, list):
|
|
||||||
if all(is_scalar(x) for x in value):
|
|
||||||
for i, x in enumerate(value):
|
|
||||||
if x is None:
|
|
||||||
rows.append(_null_row(i))
|
|
||||||
elif isinstance(x, bool):
|
|
||||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
|
||||||
elif isinstance(x, (int, float, Decimal)):
|
|
||||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
|
||||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
|
||||||
elif isinstance(x, str):
|
|
||||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
|
||||||
else:
|
|
||||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
|
||||||
return rows
|
|
||||||
# list contains objects -> one val_json per element
|
|
||||||
for i, x in enumerate(value):
|
|
||||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
|
||||||
return rows
|
|
||||||
|
|
||||||
# Dict or any other structure -> single json row
|
|
||||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
|
||||||
return rows
|
|
||||||
@@ -4,14 +4,20 @@ import shutil
 from contextlib import asynccontextmanager
 from typing import Optional

-from comfy.cli_args import args
 from alembic import command
 from alembic.config import Config
 from alembic.runtime.migration import MigrationContext
 from alembic.script import ScriptDirectory
 from sqlalchemy import create_engine, text
 from sqlalchemy.engine import make_url
-from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
+from sqlalchemy.ext.asyncio import (
+    AsyncEngine,
+    AsyncSession,
+    async_sessionmaker,
+    create_async_engine,
+)

+from comfy.cli_args import args

 LOGGER = logging.getLogger(__name__)
 ENGINE: Optional[AsyncEngine] = None
app/database/helpers/__init__.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
    add_missing_tag_for_asset_hash,
    add_missing_tag_for_asset_id,
    ensure_tags_exist,
    remove_missing_tag_for_asset_hash,
    remove_missing_tag_for_asset_id,
)

__all__ = [
    "apply_tag_filters",
    "apply_metadata_filter",
    "is_scalar",
    "project_kv",
    "ensure_tags_exist",
    "add_missing_tag_for_asset_id",
    "add_missing_tag_for_asset_hash",
    "remove_missing_tag_for_asset_id",
    "remove_missing_tag_for_asset_hash",
    "visible_owner_clause",
]
app/database/helpers/filters.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from typing import Optional, Sequence

import sqlalchemy as sa
from sqlalchemy import exists

from ..._assets_helpers import normalize_tags
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag


def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Optional[Sequence[str]],
    exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select:
    """include_tags: every tag must be present; exclude_tags: none may be present."""
    include_tags = normalize_tags(include_tags)
    exclude_tags = normalize_tags(exclude_tags)

    if include_tags:
        for tag_name in include_tags:
            stmt = stmt.where(
                exists().where(
                    (AssetInfoTag.asset_info_id == AssetInfo.id)
                    & (AssetInfoTag.tag_name == tag_name)
                )
            )

    if exclude_tags:
        stmt = stmt.where(
            ~exists().where(
                (AssetInfoTag.asset_info_id == AssetInfo.id)
                & (AssetInfoTag.tag_name.in_(exclude_tags))
            )
        )
    return stmt


def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: Optional[dict],
) -> sa.sql.Select:
    """Apply filters using asset_info_meta projection table."""
    if not metadata_filter:
        return stmt

    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        return sa.exists().where(
            AssetInfoMeta.asset_info_id == AssetInfo.id,
            AssetInfoMeta.key == key,
            *preds,
        )

    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            no_row_for_key = sa.not_(
                sa.exists().where(
                    AssetInfoMeta.asset_info_id == AssetInfo.id,
                    AssetInfoMeta.key == key,
                )
            )
            null_row = _exists_for_pred(
                key,
                AssetInfoMeta.val_json.is_(None),
                AssetInfoMeta.val_str.is_(None),
                AssetInfoMeta.val_num.is_(None),
                AssetInfoMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)

        if isinstance(value, bool):
            return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
        if isinstance(value, (int, float)):
            from decimal import Decimal
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetInfoMeta.val_num == num)
        if isinstance(value, str):
            return _exists_for_pred(key, AssetInfoMeta.val_str == value)
        return _exists_for_pred(key, AssetInfoMeta.val_json == value)

    for k, v in metadata_filter.items():
        if isinstance(v, list):
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt
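The include/exclude semantics rely on per-tag EXISTS subqueries. A self-contained toy that reproduces the same pattern against an in-memory SQLite database; the table names and columns below are stand-ins for illustration, not the real schema.

import sqlalchemy as sa

meta = sa.MetaData()
infos = sa.Table("infos", meta, sa.Column("id", sa.String, primary_key=True))
info_tags = sa.Table(
    "info_tags", meta,
    sa.Column("info_id", sa.String, sa.ForeignKey("infos.id")),
    sa.Column("tag", sa.String),
)

def with_tag_filters(stmt, include, exclude):
    # every include tag must exist for the row; no exclude tag may exist
    for tag in include:
        stmt = stmt.where(
            sa.exists().where((info_tags.c.info_id == infos.c.id) & (info_tags.c.tag == tag))
        )
    if exclude:
        stmt = stmt.where(
            ~sa.exists().where((info_tags.c.info_id == infos.c.id) & (info_tags.c.tag.in_(exclude)))
        )
    return stmt

engine = sa.create_engine("sqlite://")
meta.create_all(engine)
with engine.begin() as conn:
    conn.execute(infos.insert(), [{"id": "a"}, {"id": "b"}])
    conn.execute(info_tags.insert(), [
        {"info_id": "a", "tag": "models"},
        {"info_id": "b", "tag": "models"},
        {"info_id": "b", "tag": "missing"},
    ])
    stmt = with_tag_filters(sa.select(infos.c.id), include=["models"], exclude=["missing"])
    print(conn.execute(stmt).scalars().all())  # ['a']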
app/database/helpers/ownership.py (new file, 12 lines)
@@ -0,0 +1,12 @@
import sqlalchemy as sa

from ..models import AssetInfo


def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""

    owner_id = (owner_id or "").strip()
    if owner_id == "":
        return AssetInfo.owner_id == ""
    return AssetInfo.owner_id.in_(["", owner_id])
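The visibility rule is small enough to restate in plain Python, which can be handy when writing tests; this restatement is illustrative and not part of the diff.

def is_visible(row_owner_id: str, requester_owner_id: str) -> bool:
    """Mirror of visible_owner_clause: owner-less rows are visible to everyone."""
    requester_owner_id = (requester_owner_id or "").strip()
    if requester_owner_id == "":
        return row_owner_id == ""
    return row_owner_id in ("", requester_owner_id)

assert is_visible("", "") is True
assert is_visible("user-42", "") is False
assert is_visible("", "user-42") is True
assert is_visible("user-42", "user-42") is True
assert is_visible("user-7", "user-42") is False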
app/database/helpers/projection.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from decimal import Decimal


def is_scalar(v):
    if v is None:
        return True
    if isinstance(v, bool):
        return True
    if isinstance(v, (int, float, Decimal, str)):
        return True
    return False


def project_kv(key: str, value):
    """
    Turn a metadata key/value into typed projection rows.
    Returns list[dict] with keys:
      key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
    """
    rows: list[dict] = []

    def _null_row(ordinal: int) -> dict:
        return {
            "key": key, "ordinal": ordinal,
            "val_str": None, "val_num": None, "val_bool": None, "val_json": None
        }

    if value is None:
        rows.append(_null_row(0))
        return rows

    if is_scalar(value):
        if isinstance(value, bool):
            rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
        elif isinstance(value, (int, float, Decimal)):
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            rows.append({"key": key, "ordinal": 0, "val_num": num})
        elif isinstance(value, str):
            rows.append({"key": key, "ordinal": 0, "val_str": value})
        else:
            rows.append({"key": key, "ordinal": 0, "val_json": value})
        return rows

    if isinstance(value, list):
        if all(is_scalar(x) for x in value):
            for i, x in enumerate(value):
                if x is None:
                    rows.append(_null_row(i))
                elif isinstance(x, bool):
                    rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
                elif isinstance(x, (int, float, Decimal)):
                    num = x if isinstance(x, Decimal) else Decimal(str(x))
                    rows.append({"key": key, "ordinal": i, "val_num": num})
                elif isinstance(x, str):
                    rows.append({"key": key, "ordinal": i, "val_str": x})
                else:
                    rows.append({"key": key, "ordinal": i, "val_json": x})
            return rows
        for i, x in enumerate(value):
            rows.append({"key": key, "ordinal": i, "val_json": x})
        return rows

    rows.append({"key": key, "ordinal": 0, "val_json": value})
    return rows
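As an illustration of project_kv, here is a sample user_metadata dict and the projection rows it should produce per the branches above; the literals are shown for clarity, and only the populated value column appears for scalar rows.

from decimal import Decimal

user_metadata = {"steps": 30, "tags": ["portrait", "hires"], "extra": {"sampler": "euler"}, "note": None}

expected_rows = {
    "steps": [{"key": "steps", "ordinal": 0, "val_num": Decimal("30")}],
    "tags": [
        {"key": "tags", "ordinal": 0, "val_str": "portrait"},
        {"key": "tags", "ordinal": 1, "val_str": "hires"},
    ],
    "extra": [{"key": "extra", "ordinal": 0, "val_json": {"sampler": "euler"}}],
    "note": [{"key": "note", "ordinal": 0,
              "val_str": None, "val_num": None, "val_bool": None, "val_json": None}],
}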
app/database/helpers/tags.py (new file, 102 lines)
@@ -0,0 +1,102 @@
from typing import Iterable

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from ..._assets_helpers import normalize_tags
from ..models import Asset, AssetInfo, AssetInfoTag, Tag
from ..timeutil import utcnow


async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
    wanted = normalize_tags(list(names))
    if not wanted:
        return []
    existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
    by_name = {t.name: t for t in existing}
    to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
    if to_create:
        session.add_all(to_create)
        await session.flush()
        by_name.update({t.name: t for t in to_create})
    return [by_name[n] for n in wanted]


async def add_missing_tag_for_asset_id(
    session: AsyncSession,
    *,
    asset_id: str,
    origin: str = "automatic",
) -> int:
    """Ensure every AssetInfo for asset_id has 'missing' tag."""
    ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
    if not ids:
        return 0

    existing = {
        asset_info_id
        for (asset_info_id,) in (
            await session.execute(
                select(AssetInfoTag.asset_info_id).where(
                    AssetInfoTag.asset_info_id.in_(ids),
                    AssetInfoTag.tag_name == "missing",
                )
            )
        ).all()
    }
    to_add = [i for i in ids if i not in existing]
    if not to_add:
        return 0

    now = utcnow()
    session.add_all(
        [
            AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
            for i in to_add
        ]
    )
    await session.flush()
    return len(to_add)


async def add_missing_tag_for_asset_hash(
    session: AsyncSession,
    *,
    asset_hash: str,
    origin: str = "automatic",
) -> int:
    asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
    if not asset:
        return 0
    return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin)


async def remove_missing_tag_for_asset_id(
    session: AsyncSession,
    *,
    asset_id: str,
) -> int:
    """Remove the 'missing' tag from all AssetInfos for asset_id."""
    ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
    if not ids:
        return 0

    res = await session.execute(
        delete(AssetInfoTag).where(
            AssetInfoTag.asset_info_id.in_(ids),
            AssetInfoTag.tag_name == "missing",
        )
    )
    await session.flush()
    return int(res.rowcount or 0)


async def remove_missing_tag_for_asset_hash(
    session: AsyncSession,
    *,
    asset_hash: str,
) -> int:
    asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
    if not asset:
        return 0
    return await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
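A hedged sketch of how the reconciliation code above drives these helpers; create_session and the two tag helpers come from this diff, while asset_is_reachable is a hypothetical stand-in for the fast stat check, and the function itself is illustrative rather than part of the change.

async def reconcile_missing_tag(asset_id: str, asset_is_reachable: bool) -> int:
    # returns the number of AssetInfo rows whose 'missing' tag changed
    async with await create_session() as sess:
        if asset_is_reachable:
            changed = await remove_missing_tag_for_asset_id(sess, asset_id=asset_id)
        else:
            changed = await add_missing_tag_for_asset_id(sess, asset_id=asset_id, origin="automatic")
        await sess.commit()
        return changed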
@@ -1,27 +1,26 @@
|
|||||||
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
import uuid
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Integer,
|
JSON,
|
||||||
BigInteger,
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
CheckConstraint,
|
||||||
DateTime,
|
DateTime,
|
||||||
ForeignKey,
|
ForeignKey,
|
||||||
Index,
|
Index,
|
||||||
UniqueConstraint,
|
Integer,
|
||||||
JSON,
|
Numeric,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
CheckConstraint,
|
UniqueConstraint,
|
||||||
Numeric,
|
|
||||||
Boolean,
|
|
||||||
)
|
)
|
||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
|
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
|
||||||
|
|
||||||
from .timeutil import utcnow
|
from .timeutil import utcnow
|
||||||
|
|
||||||
|
|
||||||
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
|
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +45,8 @@ def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
|||||||
class Asset(Base):
|
class Asset(Base):
|
||||||
__tablename__ = "assets"
|
__tablename__ = "assets"
|
||||||
|
|
||||||
hash: Mapped[str] = mapped_column(String(256), primary_key=True)
|
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||||
|
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||||
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
|
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
@ -56,8 +56,8 @@ class Asset(Base):
|
|||||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||||
"AssetInfo",
|
"AssetInfo",
|
||||||
back_populates="asset",
|
back_populates="asset",
|
||||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
|
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||||
foreign_keys=lambda: [AssetInfo.asset_hash],
|
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||||
cascade="all,delete-orphan",
|
cascade="all,delete-orphan",
|
||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
@ -65,8 +65,8 @@ class Asset(Base):
|
|||||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||||
"AssetInfo",
|
"AssetInfo",
|
||||||
back_populates="preview_asset",
|
back_populates="preview_asset",
|
||||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
|
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||||
foreign_keys=lambda: [AssetInfo.preview_hash],
|
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||||
viewonly=True,
|
viewonly=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -76,36 +76,32 @@ class Asset(Base):
|
|||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
locations: Mapped[list["AssetLocation"]] = relationship(
|
|
||||||
back_populates="asset",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
passive_deletes=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_assets_mime_type", "mime_type"),
|
Index("ix_assets_mime_type", "mime_type"),
|
||||||
|
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||||
return to_dict(self, include_none=include_none)
|
return to_dict(self, include_none=include_none)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<Asset hash={self.hash[:12]}>"
|
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||||
|
|
||||||
|
|
||||||
class AssetCacheState(Base):
|
class AssetCacheState(Base):
|
||||||
__tablename__ = "asset_cache_state"
|
__tablename__ = "asset_cache_state"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
|
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
||||||
|
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||||
Index("ix_asset_cache_state_asset_hash", "asset_hash"),
|
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||||
)
|
)
|
||||||
@ -114,27 +110,7 @@ class AssetCacheState(Base):
|
|||||||
return to_dict(self, include_none=include_none)
|
return to_dict(self, include_none=include_none)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"<AssetCacheState id={self.id} hash={self.asset_hash[:12]} path={self.file_path!r}>"
|
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||||
|
|
||||||
|
|
||||||
class AssetLocation(Base):
|
|
||||||
__tablename__ = "asset_locations"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
|
|
||||||
provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs"
|
|
||||||
locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object"
|
|
||||||
expected_size_bytes: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
|
||||||
etag: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
|
||||||
last_modified: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
|
||||||
|
|
||||||
asset: Mapped["Asset"] = relationship(back_populates="locations")
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
|
|
||||||
Index("ix_asset_locations_hash", "asset_hash"),
|
|
||||||
Index("ix_asset_locations_provider", "provider"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AssetInfo(Base):
@@ -143,31 +119,23 @@ class AssetInfo(Base):
    id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
    name: Mapped[str] = mapped_column(String(512), nullable=False)
-    asset_hash: Mapped[str] = mapped_column(
-        String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
-    )
-    preview_hash: Mapped[Optional[str]] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
+    asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
+    preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
    user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
-    created_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=False), nullable=False, default=utcnow
-    )
-    updated_at: Mapped[datetime] = mapped_column(
-        DateTime(timezone=False), nullable=False, default=utcnow
-    )
-    last_access_time: Mapped[datetime] = mapped_column(
-        DateTime(timezone=False), nullable=False, default=utcnow
-    )
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
+    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
+    last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)

-    # Relationships
    asset: Mapped[Asset] = relationship(
        "Asset",
        back_populates="infos",
-        foreign_keys=[asset_hash],
+        foreign_keys=[asset_id],
+        lazy="selectin",
    )
    preview_asset: Mapped[Optional[Asset]] = relationship(
        "Asset",
        back_populates="preview_of",
-        foreign_keys=[preview_hash],
+        foreign_keys=[preview_id],
    )

    metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
@@ -186,16 +154,16 @@ class AssetInfo(Base):
    tags: Mapped[list["Tag"]] = relationship(
        secondary="asset_info_tags",
        back_populates="asset_infos",
-        lazy="joined",
+        lazy="selectin",
        viewonly=True,
        overlaps="tag_links,asset_info_links,asset_infos,tag",
    )

    __table_args__ = (
-        UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"),
+        UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
        Index("ix_assets_info_owner_name", "owner_id", "name"),
        Index("ix_assets_info_owner_id", "owner_id"),
-        Index("ix_assets_info_asset_hash", "asset_hash"),
+        Index("ix_assets_info_asset_id", "asset_id"),
        Index("ix_assets_info_name", "name"),
        Index("ix_assets_info_created_at", "created_at"),
        Index("ix_assets_info_last_access_time", "last_access_time"),
@@ -207,7 +175,7 @@ class AssetInfo(Base):
        return data

    def __repr__(self) -> str:
-        return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
+        return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"


class AssetInfoMeta(Base):
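A minimal sketch of the new id-based join, assuming an AsyncSession bound to the models above (the helper name is hypothetical and only mirrors the join used by the new services):

    from sqlalchemy import select

    async def load_info_with_asset(session, asset_info_id: str):
        # Resolve an AssetInfo together with its content row via assets.id;
        # previously this join followed Asset.hash == AssetInfo.asset_hash.
        stmt = (
            select(AssetInfo, Asset)
            .join(Asset, Asset.id == AssetInfo.asset_id)
            .where(AssetInfo.id == asset_info_id)
            .limit(1)
        )
        return (await session.execute(stmt)).first()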
File diff suppressed because it is too large

56  app/database/services/__init__.py  Normal file
@@ -0,0 +1,56 @@
from .content import (
    check_fs_asset_exists_quick,
    compute_hash_and_dedup_for_cache_state,
    ensure_seed_for_path,
    ingest_fs_asset,
    list_cache_states_with_asset_under_prefixes,
    list_unhashed_candidates_under_prefixes,
    list_verify_candidates_under_prefixes,
    redirect_all_references_then_delete_asset,
    touch_asset_infos_by_fs_path,
)
from .info import (
    add_tags_to_asset_info,
    create_asset_info_for_existing_asset,
    delete_asset_info_by_id,
    fetch_asset_info_and_asset,
    fetch_asset_info_asset_and_tags,
    get_asset_tags,
    list_asset_infos_page,
    list_tags_with_usage,
    remove_tags_from_asset_info,
    replace_asset_info_metadata_projection,
    set_asset_info_preview,
    set_asset_info_tags,
    touch_asset_info_by_id,
    update_asset_info_full,
)
from .queries import (
    asset_exists_by_hash,
    asset_info_exists_for_asset_id,
    get_asset_by_hash,
    get_asset_info_by_id,
    get_cache_state_by_asset_id,
    list_cache_states_by_asset_id,
)

__all__ = [
    # queries
    "asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
    "get_cache_state_by_asset_id",
    "list_cache_states_by_asset_id",
    # info
    "list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
    "update_asset_info_full", "replace_asset_info_metadata_projection",
    "touch_asset_info_by_id", "delete_asset_info_by_id",
    "add_tags_to_asset_info", "remove_tags_from_asset_info",
    "get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
    "fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
    # content
    "check_fs_asset_exists_quick", "ensure_seed_for_path",
    "redirect_all_references_then_delete_asset",
    "compute_hash_and_dedup_for_cache_state",
    "list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
    "ingest_fs_asset", "touch_asset_infos_by_fs_path",
    "list_cache_states_with_asset_under_prefixes",
]
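A rough usage sketch of the re-exported service helpers above; the async session factory (here called `session_maker`) and the exact import path are assumptions for illustration:

    import os
    from app.database import services  # assumed package path for the new services module

    async def seed_if_unknown(session_maker, path: str) -> None:
        # Register a file as a hash-less seed Asset only if no cache row exists for it yet.
        async with session_maker() as session:
            if not await services.check_fs_asset_exists_quick(session, file_path=path):
                st = os.stat(path)
                await services.ensure_seed_for_path(
                    session,
                    abs_path=path,
                    size_bytes=st.st_size,
                    mtime_ns=st.st_mtime_ns,
                    info_name=os.path.basename(path),
                    tags=["models"],
                )
            await session.commit()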
746  app/database/services/content.py  Normal file
@@ -0,0 +1,746 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional, Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.dialects import postgresql as d_pg
|
||||||
|
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import noload
|
||||||
|
|
||||||
|
from ..._assets_helpers import compute_model_relative_filename, normalize_tags
|
||||||
|
from ...storage import hashing as hashing_mod
|
||||||
|
from ..helpers import (
|
||||||
|
ensure_tags_exist,
|
||||||
|
remove_missing_tag_for_asset_id,
|
||||||
|
)
|
||||||
|
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
|
||||||
|
from ..timeutil import utcnow
|
||||||
|
from .info import replace_asset_info_metadata_projection
|
||||||
|
|
||||||
|
|
||||||
|
async def check_fs_asset_exists_quick(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
size_bytes: Optional[int] = None,
|
||||||
|
mtime_ns: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if a cache row exists for this absolute path and (optionally) mtime/size match."""
|
||||||
|
locator = os.path.abspath(file_path)
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
sa.select(sa.literal(True))
|
||||||
|
.select_from(AssetCacheState)
|
||||||
|
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||||
|
.where(AssetCacheState.file_path == locator)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
if mtime_ns is not None:
|
||||||
|
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
|
||||||
|
if size_bytes is not None:
|
||||||
|
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||||
|
|
||||||
|
if conds:
|
||||||
|
stmt = stmt.where(*conds)
|
||||||
|
|
||||||
|
row = (await session.execute(stmt)).first()
|
||||||
|
return row is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_seed_for_path(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
abs_path: str,
|
||||||
|
size_bytes: int,
|
||||||
|
mtime_ns: int,
|
||||||
|
info_name: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. Returns asset_id."""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
now = utcnow()
|
||||||
|
|
||||||
|
state = (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetCacheState, Asset)
|
||||||
|
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||||
|
.where(AssetCacheState.file_path == locator)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).first()
|
||||||
|
if state:
|
||||||
|
state_row: AssetCacheState = state[0]
|
||||||
|
asset_row: Asset = state[1]
|
||||||
|
changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != int(mtime_ns)
|
||||||
|
if changed:
|
||||||
|
state_row.mtime_ns = int(mtime_ns)
|
||||||
|
state_row.needs_verify = True
|
||||||
|
if asset_row.size_bytes == 0 and size_bytes > 0:
|
||||||
|
asset_row.size_bytes = int(size_bytes)
|
||||||
|
return asset_row.id
|
||||||
|
|
||||||
|
# Create new asset (hash=NULL)
|
||||||
|
asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now)
|
||||||
|
session.add(asset)
|
||||||
|
await session.flush() # to get id
|
||||||
|
|
||||||
|
cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=int(mtime_ns), needs_verify=False)
|
||||||
|
session.add(cs)
|
||||||
|
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=info_name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(info)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# Attach tags
|
||||||
|
want = normalize_tags(tags)
|
||||||
|
if want:
|
||||||
|
await ensure_tags_exist(session, want, tag_type="user")
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now)
|
||||||
|
for t in want
|
||||||
|
])
|
||||||
|
|
||||||
|
await session.flush()
|
||||||
|
return asset.id
|
||||||
|
|
||||||
|
|
||||||
|
async def redirect_all_references_then_delete_asset(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
duplicate_asset_id: str,
|
||||||
|
canonical_asset_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
|
||||||
|
|
||||||
|
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
|
||||||
|
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
|
||||||
|
- Otherwise, simply repoint the AssetInfo.asset_id.
|
||||||
|
- Always retarget AssetCacheState rows.
|
||||||
|
- Finally delete the duplicate Asset row.
|
||||||
|
"""
|
||||||
|
if duplicate_asset_id == canonical_asset_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
|
||||||
|
dup_infos = (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
|
||||||
|
)
|
||||||
|
).unique().scalars().all()
|
||||||
|
|
||||||
|
for info in dup_infos:
|
||||||
|
# Try to find an existing collision on canonical
|
||||||
|
existing = (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == canonical_asset_id,
|
||||||
|
AssetInfo.owner_id == info.owner_id,
|
||||||
|
AssetInfo.name == info.name,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalars().first()
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
# Merge metadata (prefer existing keys, fill gaps from duplicate)
|
||||||
|
merged_meta = dict(existing.user_metadata or {})
|
||||||
|
other_meta = info.user_metadata or {}
|
||||||
|
for k, v in other_meta.items():
|
||||||
|
if k not in merged_meta:
|
||||||
|
merged_meta[k] = v
|
||||||
|
if merged_meta != (existing.user_metadata or {}):
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=existing.id,
|
||||||
|
user_metadata=merged_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge tags (union)
|
||||||
|
existing_tags = {
|
||||||
|
t for (t,) in (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
from_tags = {
|
||||||
|
t for (t,) in (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
to_add = sorted(from_tags - existing_tags)
|
||||||
|
if to_add:
|
||||||
|
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||||
|
now = utcnow()
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
|
||||||
|
for t in to_add
|
||||||
|
])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# Merge preview and times
|
||||||
|
if existing.preview_id is None and info.preview_id is not None:
|
||||||
|
existing.preview_id = info.preview_id
|
||||||
|
if info.last_access_time and (
|
||||||
|
existing.last_access_time is None or info.last_access_time > existing.last_access_time
|
||||||
|
):
|
||||||
|
existing.last_access_time = info.last_access_time
|
||||||
|
existing.updated_at = utcnow()
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
|
||||||
|
await session.delete(info)
|
||||||
|
await session.flush()
|
||||||
|
else:
|
||||||
|
# Simple retarget
|
||||||
|
info.asset_id = canonical_asset_id
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# 2) Repoint cache states and previews
|
||||||
|
await session.execute(
|
||||||
|
sa.update(AssetCacheState)
|
||||||
|
.where(AssetCacheState.asset_id == duplicate_asset_id)
|
||||||
|
.values(asset_id=canonical_asset_id)
|
||||||
|
)
|
||||||
|
await session.execute(
|
||||||
|
sa.update(AssetInfo)
|
||||||
|
.where(AssetInfo.preview_id == duplicate_asset_id)
|
||||||
|
.values(preview_id=canonical_asset_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3) Remove duplicate Asset
|
||||||
|
dup = await session.get(Asset, duplicate_asset_id)
|
||||||
|
if dup:
|
||||||
|
await session.delete(dup)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
async def compute_hash_and_dedup_for_cache_state(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
state_id: int,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Compute hash for the given cache state, deduplicate, and settle verify cases.
|
||||||
|
|
||||||
|
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
|
||||||
|
"""
|
||||||
|
state = await session.get(AssetCacheState, state_id)
|
||||||
|
if not state:
|
||||||
|
return None
|
||||||
|
|
||||||
|
path = state.file_path
|
||||||
|
try:
|
||||||
|
if not os.path.isfile(path):
|
||||||
|
# File vanished: drop the state. If the Asset was a seed (hash NULL)
|
||||||
|
# and has no other states, drop the Asset too.
|
||||||
|
asset = await session.get(Asset, state.asset_id)
|
||||||
|
await session.delete(state)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
if asset and asset.hash is None:
|
||||||
|
remaining = (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(sa.func.count())
|
||||||
|
.select_from(AssetCacheState)
|
||||||
|
.where(AssetCacheState.asset_id == asset.id)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
if int(remaining or 0) == 0:
|
||||||
|
await session.delete(asset)
|
||||||
|
await session.flush()
|
||||||
|
return None
|
||||||
|
|
||||||
|
digest = await hashing_mod.blake3_hash(path)
|
||||||
|
new_hash = f"blake3:{digest}"
|
||||||
|
|
||||||
|
st = os.stat(path, follow_symlinks=True)
|
||||||
|
new_size = int(st.st_size)
|
||||||
|
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||||
|
|
||||||
|
# Current asset of this state
|
||||||
|
this_asset = await session.get(Asset, state.asset_id)
|
||||||
|
|
||||||
|
# If the state got orphaned somehow (race), just reattach appropriately.
|
||||||
|
if not this_asset:
|
||||||
|
canonical = (
|
||||||
|
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
if canonical:
|
||||||
|
state.asset_id = canonical.id
|
||||||
|
else:
|
||||||
|
now = utcnow()
|
||||||
|
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||||
|
session.add(new_asset)
|
||||||
|
await session.flush()
|
||||||
|
state.asset_id = new_asset.id
|
||||||
|
state.mtime_ns = mtime_ns
|
||||||
|
state.needs_verify = False
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await session.flush()
|
||||||
|
return state.asset_id
|
||||||
|
|
||||||
|
# 1) Seed asset case (hash is NULL): claim or merge into canonical
|
||||||
|
if this_asset.hash is None:
|
||||||
|
canonical = (
|
||||||
|
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
|
||||||
|
if canonical and canonical.id != this_asset.id:
|
||||||
|
# Merge seed asset into canonical (safe, collision-aware)
|
||||||
|
await redirect_all_references_then_delete_asset(
|
||||||
|
session,
|
||||||
|
duplicate_asset_id=this_asset.id,
|
||||||
|
canonical_asset_id=canonical.id,
|
||||||
|
)
|
||||||
|
# Refresh state after the merge
|
||||||
|
state = await session.get(AssetCacheState, state_id)
|
||||||
|
if state:
|
||||||
|
state.mtime_ns = mtime_ns
|
||||||
|
state.needs_verify = False
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await session.flush()
|
||||||
|
return canonical.id
|
||||||
|
|
||||||
|
# No canonical: try to claim the hash; handle races with a SAVEPOINT
|
||||||
|
try:
|
||||||
|
async with session.begin_nested():
|
||||||
|
this_asset.hash = new_hash
|
||||||
|
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||||
|
this_asset.size_bytes = new_size
|
||||||
|
await session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
# Someone else claimed it concurrently; fetch canonical and merge
|
||||||
|
canonical = (
|
||||||
|
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
if canonical and canonical.id != this_asset.id:
|
||||||
|
await redirect_all_references_then_delete_asset(
|
||||||
|
session,
|
||||||
|
duplicate_asset_id=this_asset.id,
|
||||||
|
canonical_asset_id=canonical.id,
|
||||||
|
)
|
||||||
|
state = await session.get(AssetCacheState, state_id)
|
||||||
|
if state:
|
||||||
|
state.mtime_ns = mtime_ns
|
||||||
|
state.needs_verify = False
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await session.flush()
|
||||||
|
return canonical.id
|
||||||
|
# If we got here, the integrity error was not about hash uniqueness
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Claimed successfully
|
||||||
|
state.mtime_ns = mtime_ns
|
||||||
|
state.needs_verify = False
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await session.flush()
|
||||||
|
return this_asset.id
|
||||||
|
|
||||||
|
# 2) Verify case for hashed assets
|
||||||
|
if this_asset.hash == new_hash:
|
||||||
|
# Content unchanged; tidy up sizes/mtime
|
||||||
|
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||||
|
this_asset.size_bytes = new_size
|
||||||
|
state.mtime_ns = mtime_ns
|
||||||
|
state.needs_verify = False
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await session.flush()
|
||||||
|
return this_asset.id
|
||||||
|
|
||||||
|
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
|
||||||
|
canonical = (
|
||||||
|
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
if canonical:
|
||||||
|
target_id = canonical.id
|
||||||
|
else:
|
||||||
|
now = utcnow()
|
||||||
|
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||||
|
session.add(new_asset)
|
||||||
|
await session.flush()
|
||||||
|
target_id = new_asset.id
|
||||||
|
|
||||||
|
state.asset_id = target_id
|
||||||
|
state.mtime_ns = mtime_ns
|
||||||
|
state.needs_verify = False
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
await session.flush()
|
||||||
|
return target_id
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# Propagate; caller records the error and continues the worker.
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
async def list_unhashed_candidates_under_prefixes(
|
||||||
|
session: AsyncSession, *, prefixes: Sequence[str]
|
||||||
|
) -> list[int]:
|
||||||
|
if not prefixes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
for p in prefixes:
|
||||||
|
base = os.path.abspath(p)
|
||||||
|
if not base.endswith(os.sep):
|
||||||
|
base += os.sep
|
||||||
|
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||||
|
|
||||||
|
rows = (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetCacheState.id)
|
||||||
|
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||||
|
.where(Asset.hash.is_(None))
|
||||||
|
.where(sa.or_(*conds))
|
||||||
|
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
seen = set()
|
||||||
|
result: list[int] = []
|
||||||
|
for sid in rows:
|
||||||
|
st = await session.get(AssetCacheState, sid)
|
||||||
|
if st and st.asset_id not in seen:
|
||||||
|
seen.add(st.asset_id)
|
||||||
|
result.append(sid)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def list_verify_candidates_under_prefixes(
|
||||||
|
session: AsyncSession, *, prefixes: Sequence[str]
|
||||||
|
) -> Union[list[int], Sequence[int]]:
|
||||||
|
if not prefixes:
|
||||||
|
return []
|
||||||
|
conds = []
|
||||||
|
for p in prefixes:
|
||||||
|
base = os.path.abspath(p)
|
||||||
|
if not base.endswith(os.sep):
|
||||||
|
base += os.sep
|
||||||
|
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||||
|
|
||||||
|
return (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetCacheState.id)
|
||||||
|
.where(AssetCacheState.needs_verify.is_(True))
|
||||||
|
.where(sa.or_(*conds))
|
||||||
|
.order_by(AssetCacheState.id.asc())
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
async def ingest_fs_asset(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
abs_path: str,
|
||||||
|
size_bytes: int,
|
||||||
|
mtime_ns: int,
|
||||||
|
mime_type: Optional[str] = None,
|
||||||
|
info_name: Optional[str] = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
preview_id: Optional[str] = None,
|
||||||
|
user_metadata: Optional[dict] = None,
|
||||||
|
tags: Sequence[str] = (),
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
require_existing_tags: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Idempotently upsert:
|
||||||
|
- Asset by content hash (create if missing)
|
||||||
|
- AssetCacheState(file_path) pointing to asset_id
|
||||||
|
- Optionally AssetInfo + tag links and metadata projection
|
||||||
|
Returns flags and ids.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
now = utcnow()
|
||||||
|
|
||||||
|
if preview_id:
|
||||||
|
if not await session.get(Asset, preview_id):
|
||||||
|
preview_id = None
|
||||||
|
|
||||||
|
out: dict[str, Any] = {
|
||||||
|
"asset_created": False,
|
||||||
|
"asset_updated": False,
|
||||||
|
"state_created": False,
|
||||||
|
"state_updated": False,
|
||||||
|
"asset_info_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 1) Asset by hash
|
||||||
|
asset = (
|
||||||
|
await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
if not asset:
|
||||||
|
async with session.begin_nested():
|
||||||
|
asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now)
|
||||||
|
session.add(asset)
|
||||||
|
await session.flush()
|
||||||
|
out["asset_created"] = True
|
||||||
|
else:
|
||||||
|
changed = False
|
||||||
|
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||||
|
asset.size_bytes = int(size_bytes)
|
||||||
|
changed = True
|
||||||
|
if mime_type and asset.mime_type != mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
out["asset_updated"] = True
|
||||||
|
|
||||||
|
# 2) AssetCacheState upsert by file_path (unique)
|
||||||
|
vals = {
|
||||||
|
"asset_id": asset.id,
|
||||||
|
"file_path": locator,
|
||||||
|
"mtime_ns": int(mtime_ns),
|
||||||
|
}
|
||||||
|
dialect = session.bind.dialect.name
|
||||||
|
if dialect == "sqlite":
|
||||||
|
ins = (
|
||||||
|
d_sqlite.insert(AssetCacheState)
|
||||||
|
.values(**vals)
|
||||||
|
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||||
|
)
|
||||||
|
elif dialect == "postgresql":
|
||||||
|
ins = (
|
||||||
|
d_pg.insert(AssetCacheState)
|
||||||
|
.values(**vals)
|
||||||
|
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Unsupported database dialect: {dialect}")
|
||||||
|
|
||||||
|
res = await session.execute(ins)
|
||||||
|
if int(res.rowcount or 0) > 0:
|
||||||
|
out["state_created"] = True
|
||||||
|
else:
|
||||||
|
upd = (
|
||||||
|
sa.update(AssetCacheState)
|
||||||
|
.where(AssetCacheState.file_path == locator)
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetCacheState.asset_id != asset.id,
|
||||||
|
AssetCacheState.mtime_ns.is_(None),
|
||||||
|
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||||
|
)
|
||||||
|
res2 = await session.execute(upd)
|
||||||
|
if int(res2.rowcount or 0) > 0:
|
||||||
|
out["state_updated"] = True
|
||||||
|
|
||||||
|
# 3) Optional AssetInfo + tags + metadata
|
||||||
|
if info_name:
|
||||||
|
# upsert by (asset_id, owner_id, name)
|
||||||
|
try:
|
||||||
|
async with session.begin_nested():
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=info_name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=preview_id,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(info)
|
||||||
|
await session.flush()
|
||||||
|
out["asset_info_id"] = info.id
|
||||||
|
except IntegrityError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
existing_info = (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == asset.id,
|
||||||
|
AssetInfo.name == info_name,
|
||||||
|
(AssetInfo.owner_id == owner_id),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalar_one_or_none()
|
||||||
|
if not existing_info:
|
||||||
|
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||||
|
|
||||||
|
if preview_id and existing_info.preview_id != preview_id:
|
||||||
|
existing_info.preview_id = preview_id
|
||||||
|
|
||||||
|
existing_info.updated_at = now
|
||||||
|
if existing_info.last_access_time < now:
|
||||||
|
existing_info.last_access_time = now
|
||||||
|
await session.flush()
|
||||||
|
out["asset_info_id"] = existing_info.id
|
||||||
|
|
||||||
|
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
if norm and out["asset_info_id"] is not None:
|
||||||
|
if not require_existing_tags:
|
||||||
|
await ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
existing_tag_names = set(
|
||||||
|
name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||||
|
)
|
||||||
|
missing = [t for t in norm if t not in existing_tag_names]
|
||||||
|
if missing and require_existing_tags:
|
||||||
|
raise ValueError(f"Unknown tags: {missing}")
|
||||||
|
|
||||||
|
existing_links = set(
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||||
|
if to_add:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
tag_name=t,
|
||||||
|
origin=tag_origin,
|
||||||
|
added_at=now,
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
# metadata["filename"] hack
|
||||||
|
if out["asset_info_id"] is not None:
|
||||||
|
primary_path = (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetCacheState.file_path)
|
||||||
|
.where(AssetCacheState.asset_id == asset.id)
|
||||||
|
.order_by(AssetCacheState.id.asc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).scalars().first()
|
||||||
|
computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
|
||||||
|
|
||||||
|
current_meta = existing_info.user_metadata or {}
|
||||||
|
new_meta = dict(current_meta)
|
||||||
|
if user_metadata is not None:
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
new_meta[k] = v
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
|
||||||
|
if new_meta != current_meta:
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
user_metadata=new_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
async def touch_asset_infos_by_fs_path(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
file_path: str,
|
||||||
|
ts: Optional[datetime] = None,
|
||||||
|
only_if_newer: bool = True,
|
||||||
|
) -> int:
|
||||||
|
locator = os.path.abspath(file_path)
|
||||||
|
ts = ts or utcnow()
|
||||||
|
|
||||||
|
stmt = sa.update(AssetInfo).where(
|
||||||
|
sa.exists(
|
||||||
|
sa.select(sa.literal(1))
|
||||||
|
.select_from(AssetCacheState)
|
||||||
|
.where(
|
||||||
|
AssetCacheState.asset_id == AssetInfo.asset_id,
|
||||||
|
AssetCacheState.file_path == locator,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if only_if_newer:
|
||||||
|
stmt = stmt.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetInfo.last_access_time.is_(None),
|
||||||
|
AssetInfo.last_access_time < ts,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = stmt.values(last_access_time=ts)
|
||||||
|
|
||||||
|
res = await session.execute(stmt)
|
||||||
|
return int(res.rowcount or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_cache_states_with_asset_under_prefixes(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
prefixes: Sequence[str],
|
||||||
|
) -> list[tuple[AssetCacheState, Optional[str], int]]:
|
||||||
|
"""Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
|
||||||
|
if not prefixes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
conds = []
|
||||||
|
for p in prefixes:
|
||||||
|
if not p:
|
||||||
|
continue
|
||||||
|
base = os.path.abspath(p)
|
||||||
|
if not base.endswith(os.sep):
|
||||||
|
base = base + os.sep
|
||||||
|
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||||
|
|
||||||
|
if not conds:
|
||||||
|
return []
|
||||||
|
|
||||||
|
rows = (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetCacheState, Asset.hash, Asset.size_bytes)
|
||||||
|
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||||
|
.where(sa.or_(*conds))
|
||||||
|
.order_by(AssetCacheState.id.asc())
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
return [(r[0], r[1], int(r[2] or 0)) for r in rows]
|
||||||
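content.py above stores hashes as "blake3:<hex digest>" and settles seed or stale cache states one at a time; a minimal driver sketch over those helpers follows (the session factory and prefix list are assumptions, not part of this file):

    async def hash_pending_assets(session_maker, prefixes: list[str]) -> None:
        # Hypothetical worker pass: hash every seed Asset found under the given roots.
        async with session_maker() as session:
            state_ids = await list_unhashed_candidates_under_prefixes(session, prefixes=prefixes)
            for sid in state_ids:
                # Claims the computed hash for the seed Asset, or merges it into an
                # existing canonical Asset, then clears needs_verify for the state.
                await compute_hash_and_dedup_for_cache_state(session, state_id=sid)
            await session.commit()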
579  app/database/services/info.py  Normal file
@@ -0,0 +1,579 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional, Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import delete, func, select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import contains_eager, noload
|
||||||
|
|
||||||
|
from ..._assets_helpers import compute_model_relative_filename, normalize_tags
|
||||||
|
from ..helpers import (
|
||||||
|
apply_metadata_filter,
|
||||||
|
apply_tag_filters,
|
||||||
|
ensure_tags_exist,
|
||||||
|
project_kv,
|
||||||
|
visible_owner_clause,
|
||||||
|
)
|
||||||
|
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||||
|
from ..timeutil import utcnow
|
||||||
|
from .queries import get_asset_by_hash, get_cache_state_by_asset_id
|
||||||
|
|
||||||
|
|
||||||
|
async def list_asset_infos_page(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
owner_id: str = "",
|
||||||
|
include_tags: Optional[Sequence[str]] = None,
|
||||||
|
exclude_tags: Optional[Sequence[str]] = None,
|
||||||
|
name_contains: Optional[str] = None,
|
||||||
|
metadata_filter: Optional[dict] = None,
|
||||||
|
limit: int = 20,
|
||||||
|
offset: int = 0,
|
||||||
|
sort: str = "created_at",
|
||||||
|
order: str = "desc",
|
||||||
|
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||||
|
base = (
|
||||||
|
select(AssetInfo)
|
||||||
|
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||||
|
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||||
|
.where(visible_owner_clause(owner_id))
|
||||||
|
)
|
||||||
|
|
||||||
|
if name_contains:
|
||||||
|
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||||
|
|
||||||
|
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
|
base = apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
|
sort = (sort or "created_at").lower()
|
||||||
|
order = (order or "desc").lower()
|
||||||
|
sort_map = {
|
||||||
|
"name": AssetInfo.name,
|
||||||
|
"created_at": AssetInfo.created_at,
|
||||||
|
"updated_at": AssetInfo.updated_at,
|
||||||
|
"last_access_time": AssetInfo.last_access_time,
|
||||||
|
"size": Asset.size_bytes,
|
||||||
|
}
|
||||||
|
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||||
|
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||||
|
|
||||||
|
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||||
|
|
||||||
|
count_stmt = (
|
||||||
|
select(func.count())
|
||||||
|
.select_from(AssetInfo)
|
||||||
|
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||||
|
.where(visible_owner_clause(owner_id))
|
||||||
|
)
|
||||||
|
if name_contains:
|
||||||
|
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||||
|
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
|
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
|
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||||
|
|
||||||
|
infos = (await session.execute(base)).unique().scalars().all()
|
||||||
|
|
||||||
|
id_list: list[str] = [i.id for i in infos]
|
||||||
|
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||||
|
if id_list:
|
||||||
|
rows = await session.execute(
|
||||||
|
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||||
|
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||||
|
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||||
|
)
|
||||||
|
for aid, tag_name in rows.all():
|
||||||
|
tag_map[aid].append(tag_name)
|
||||||
|
|
||||||
|
return infos, tag_map, total
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_asset_info_and_asset(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||||
|
stmt = (
|
||||||
|
select(AssetInfo, Asset)
|
||||||
|
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||||
|
.where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
)
|
||||||
|
row = await session.execute(stmt)
|
||||||
|
pair = row.first()
|
||||||
|
if not pair:
|
||||||
|
return None
|
||||||
|
return pair[0], pair[1]
|
||||||
|
|
||||||
|
|
||||||
|
async def fetch_asset_info_asset_and_tags(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||||
|
stmt = (
|
||||||
|
select(AssetInfo, Asset, Tag.name)
|
||||||
|
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||||
|
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||||
|
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||||
|
.where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
.order_by(Tag.name.asc())
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = (await session.execute(stmt)).all()
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_info, first_asset, _ = rows[0]
|
||||||
|
tags: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for _info, _asset, tag_name in rows:
|
||||||
|
if tag_name and tag_name not in seen:
|
||||||
|
seen.add(tag_name)
|
||||||
|
tags.append(tag_name)
|
||||||
|
return first_info, first_asset, tags
|
||||||
|
|
||||||
|
|
||||||
|
async def create_asset_info_for_existing_asset(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
name: str,
|
||||||
|
user_metadata: Optional[dict] = None,
|
||||||
|
tags: Optional[Sequence[str]] = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> AssetInfo:
|
||||||
|
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||||
|
now = utcnow()
|
||||||
|
asset = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||||
|
if not asset:
|
||||||
|
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||||
|
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async with session.begin_nested():
|
||||||
|
session.add(info)
|
||||||
|
await session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
existing = (
|
||||||
|
await session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == asset.id,
|
||||||
|
AssetInfo.name == name,
|
||||||
|
AssetInfo.owner_id == owner_id,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalars().first()
|
||||||
|
if not existing:
|
||||||
|
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# metadata["filename"] hack
|
||||||
|
new_meta = dict(user_metadata or {})
|
||||||
|
computed_filename = None
|
||||||
|
try:
|
||||||
|
state = await get_cache_state_by_asset_id(session, asset_id=asset.id)
|
||||||
|
if state and state.file_path:
|
||||||
|
computed_filename = compute_model_relative_filename(state.file_path)
|
||||||
|
except Exception:
|
||||||
|
computed_filename = None
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
if new_meta:
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=info.id,
|
||||||
|
user_metadata=new_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
await set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=info.id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
async def set_asset_info_tags(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
) -> dict:
|
||||||
|
desired = normalize_tags(tags)
|
||||||
|
|
||||||
|
current = set(
|
||||||
|
tag_name for (tag_name,) in (
|
||||||
|
await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
to_add = [t for t in desired if t not in current]
|
||||||
|
to_remove = [t for t in current if t not in desired]
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||||
|
for t in to_add
|
||||||
|
])
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
await session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||||
|
|
||||||
|
|
||||||
|
async def update_asset_info_full(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
tags: Optional[Sequence[str]] = None,
|
||||||
|
user_metadata: Optional[dict] = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
asset_info_row: Any = None,
|
||||||
|
) -> AssetInfo:
|
||||||
|
if not asset_info_row:
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
else:
|
||||||
|
info = asset_info_row
|
||||||
|
|
||||||
|
touched = False
|
||||||
|
if name is not None and name != info.name:
|
||||||
|
info.name = name
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
computed_filename = None
|
||||||
|
try:
|
||||||
|
state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id)
|
||||||
|
if state and state.file_path:
|
||||||
|
computed_filename = compute_model_relative_filename(state.file_path)
|
||||||
|
except Exception:
|
||||||
|
computed_filename = None
|
||||||
|
|
||||||
|
if user_metadata is not None:
|
||||||
|
new_meta = dict(user_metadata)
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
else:
|
||||||
|
if computed_filename:
|
||||||
|
current_meta = info.user_metadata or {}
|
||||||
|
if current_meta.get("filename") != computed_filename:
|
||||||
|
new_meta = dict(current_meta)
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
await replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
await set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if touched and user_metadata is None:
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
async def replace_asset_info_metadata_projection(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
user_metadata: Optional[dict],
|
||||||
|
) -> None:
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
info.user_metadata = user_metadata or {}
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
if not user_metadata:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows: list[AssetInfoMeta] = []
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
for r in project_kv(k, v):
|
||||||
|
rows.append(
|
||||||
|
AssetInfoMeta(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
key=r["key"],
|
||||||
|
ordinal=int(r["ordinal"]),
|
||||||
|
val_str=r.get("val_str"),
|
||||||
|
val_num=r.get("val_num"),
|
||||||
|
val_bool=r.get("val_bool"),
|
||||||
|
val_json=r.get("val_json"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rows:
|
||||||
|
session.add_all(rows)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
async def touch_asset_info_by_id(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
ts: Optional[datetime] = None,
|
||||||
|
only_if_newer: bool = True,
|
||||||
|
) -> int:
|
||||||
|
ts = ts or utcnow()
|
||||||
|
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||||
|
if only_if_newer:
|
||||||
|
stmt = stmt.where(
|
||||||
|
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||||
|
)
|
||||||
|
stmt = stmt.values(last_access_time=ts)
|
||||||
|
res = await session.execute(stmt)
|
||||||
|
return int(res.rowcount or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
|
||||||
|
res = await session.execute(delete(AssetInfo).where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
))
|
||||||
|
return bool(res.rowcount)
|
||||||
|
|
||||||
|
|
||||||
|
async def add_tags_to_asset_info(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
create_if_missing: bool = True,
|
||||||
|
asset_info_row: Any = None,
|
||||||
|
) -> dict:
|
||||||
|
if not asset_info_row:
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"added": [], "already_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
if create_if_missing:
|
||||||
|
await ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
current = {
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
want = set(norm)
|
||||||
|
to_add = sorted(want - current)
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
async with session.begin_nested() as nested:
|
||||||
|
try:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tag_name=t,
|
||||||
|
origin=origin,
|
||||||
|
added_at=utcnow(),
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
await nested.rollback()
|
||||||
|
|
||||||
|
after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
|
||||||
|
return {
|
||||||
|
"added": sorted(((after - current) & want)),
|
||||||
|
"already_present": sorted(want & current),
|
||||||
|
"total_tags": sorted(after),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def remove_tags_from_asset_info(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
) -> dict:
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": [], "not_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
existing = {
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
to_remove = sorted(set(t for t in norm if t in existing))
|
||||||
|
not_present = sorted(set(t for t in norm if t not in existing))
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
await session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(
|
||||||
|
AssetInfoTag.asset_info_id == asset_info_id,
|
||||||
|
AssetInfoTag.tag_name.in_(to_remove),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
total = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||||
|
|
||||||
|
|
||||||
|
async def list_tags_with_usage(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
prefix: Optional[str] = None,
|
||||||
|
limit: int = 100,
|
||||||
|
offset: int = 0,
|
||||||
|
include_zero: bool = True,
|
||||||
|
order: str = "count_desc",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> tuple[list[tuple[str, str, int]], int]:
|
||||||
|
counts_sq = (
|
||||||
|
select(
|
||||||
|
AssetInfoTag.tag_name.label("tag_name"),
|
||||||
|
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||||
|
)
|
||||||
|
.select_from(AssetInfoTag)
|
||||||
|
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||||
|
.where(visible_owner_clause(owner_id))
|
||||||
|
.group_by(AssetInfoTag.tag_name)
|
||||||
|
.subquery()
|
||||||
|
)
|
||||||
|
|
||||||
|
q = (
|
||||||
|
select(
|
||||||
|
Tag.name,
|
||||||
|
Tag.tag_type,
|
||||||
|
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||||
|
)
|
||||||
|
.select_from(Tag)
|
||||||
|
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
||||||
|
|
||||||
|
if not include_zero:
|
||||||
|
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||||
|
|
||||||
|
if order == "name_asc":
|
||||||
|
q = q.order_by(Tag.name.asc())
|
||||||
|
else:
|
||||||
|
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||||
|
|
||||||
|
total_q = select(func.count()).select_from(Tag)
|
||||||
|
if prefix:
|
||||||
|
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
||||||
|
if not include_zero:
|
||||||
|
total_q = total_q.where(
|
||||||
|
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = (await session.execute(q.limit(limit).offset(offset))).all()
|
||||||
|
total = (await session.execute(total_q)).scalar_one()
|
||||||
|
|
||||||
|
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||||
|
return rows_norm, int(total or 0)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
|
||||||
|
return [
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
await session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def set_asset_info_preview(
|
||||||
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
preview_asset_id: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||||
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
if preview_asset_id is None:
|
||||||
|
info.preview_id = None
|
||||||
|
else:
|
||||||
|
# validate preview asset exists
|
||||||
|
if not await session.get(Asset, preview_asset_id):
|
||||||
|
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||||
|
info.preview_id = preview_asset_id
|
||||||
|
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
await session.flush()
|
||||||
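A short sketch of paging through AssetInfo rows with the listing helper above; the filter values are arbitrary and `session` is an assumed AsyncSession:

    async def show_first_page(session) -> None:
        # Hypothetical listing call for illustration.
        infos, tag_map, total = await list_asset_infos_page(
            session,
            owner_id="",
            include_tags=["models"],
            name_contains="sdxl",
            limit=20,
            offset=0,
            sort="created_at",
            order="desc",
        )
        print(f"{total} matching assets")
        for info in infos:
            print(info.name, tag_map.get(info.id, []))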
59  app/database/services/queries.py  Normal file
@@ -0,0 +1,59 @@
from typing import Optional, Sequence, Union

import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from ..models import Asset, AssetCacheState, AssetInfo


async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
    row = (
        await session.execute(
            select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
        )
    ).first()
    return row is not None


async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
    return (
        await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()


async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
    return await session.get(AssetInfo, asset_info_id)


async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
    q = (
        select(sa.literal(True))
        .select_from(AssetInfo)
        .where(AssetInfo.asset_id == asset_id)
        .limit(1)
    )
    return (await session.execute(q)).first() is not None


async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
    return (
        await session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
            .limit(1)
        )
    ).scalars().first()


async def list_cache_states_by_asset_id(
    session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
    return (
        await session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
        )
    ).scalars().all()
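Combining the lookup and ingest helpers, a hedged sketch of registering a file only when its content hash is new (the digest and path values are placeholders):

    import os

    async def ingest_if_new(session, file_path: str, digest: str) -> None:
        # Hypothetical ingest guard for illustration.
        asset_hash = f"blake3:{digest}"
        if await asset_exists_by_hash(session, asset_hash=asset_hash):
            return
        st = os.stat(file_path)
        await ingest_fs_asset(
            session,
            asset_hash=asset_hash,
            abs_path=file_path,
            size_bytes=st.st_size,
            mtime_ns=st.st_mtime_ns,
            info_name=os.path.basename(file_path),
        )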
@@ -212,7 +212,6 @@ database_default_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
-parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")

if comfy.options.args_parsing:
4  main.py
@@ -279,11 +279,11 @@ def cleanup_temp():
    shutil.rmtree(temp_dir, ignore_errors=True)

async def setup_database():
-    from app import init_db_engine, start_background_assets_scan
+    from app import init_db_engine, sync_seed_assets

    await init_db_engine()
    if not args.disable_assets_autoscan:
-        await start_background_assets_scan()
+        await sync_seed_assets(["models", "input", "output"])


def start_comfyui(asyncio_loop=None):
@ -37,7 +37,7 @@ from app.model_manager import ModelFileManager
 from app.custom_node_manager import CustomNodeManager
 from typing import Optional, Union
 from api_server.routes.internal.internal_routes import InternalRoutes
-from app.api.assets_routes import register_assets_system
+from app import sync_seed_assets, register_assets_system
 from protocol import BinaryEventTypes
 
 async def send_socket_catch_exception(function, message):
@ -629,6 +629,7 @@ class PromptServer():
 
         @routes.get("/object_info")
         async def get_object_info(request):
+            await sync_seed_assets(["models"])
             with folder_paths.cache_helper:
                 out = {}
                 for x in nodes.NODE_CLASS_MAPPINGS:
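Client-side, the effect of the added await sync_seed_assets(["models"]) is that a newly added model file should be reflected on the next /object_info request without restarting the server. A hedged fetch sketch follows; the endpoint path comes from the route above, while the client-session setup is an assumption.

# Illustrative client call; base_url is a placeholder such as "http://127.0.0.1:8188".
import aiohttp

async def fetch_object_info(base_url: str) -> dict:
    async with aiohttp.ClientSession() as session:
        async with session.get(f"{base_url}/object_info") as resp:
            resp.raise_for_status()
            return await resp.json()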
@ -118,6 +118,16 @@ async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, se
         assert rh2.status == 404
 
 
+@pytest.mark.asyncio
+async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str):
+    # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
+    # aiohttp exposes an empty body for HEAD, so validate status and that there is no payload.
+    async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh:
+        assert rh.status == 400
+        body = await rh.read()
+        assert body == b""
+
+
 @pytest.mark.asyncio
 async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str):
     bogus = str(uuid.uuid4())
@ -166,12 +176,3 @@ async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, a
         body = await r.json()
         assert r.status == 400
         assert body["error"]["code"] == "INVALID_BODY"
-
-
-@pytest.mark.asyncio
-async def test_head_asset_bad_hash(http: aiohttp.ClientSession, api_base: str):
-    # Invalid format
-    async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3:
-        jb = await rh3.json()
-        assert rh3.status == 400
-        assert jb is None  # HEAD request should not include "body" in response
@ -66,23 +66,32 @@ async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, s
     async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
         b1 = await r1.json()
         assert r1.status == 200, b1
-        # normalized and deduplicated
-        assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"]
+        # normalized, deduplicated; 'unit-tests' was already present from the seed
+        assert set(b1["added"]) == {"newtag", "beta"}
+        assert set(b1["already_present"]) == {"unit-tests"}
+        assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
 
     async with http.get(f"{api_base}/api/assets/{aid}") as rg:
         g = await rg.json()
         assert rg.status == 200
         tags_now = set(g["tags"])
-        assert "newtag" in tags_now
-        assert "beta" in tags_now
+        assert {"newtag", "beta"}.issubset(tags_now)
 
     # Remove a tag and a non-existent tag
     payload_del = {"tags": ["newtag", "does-not-exist"]}
     async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
         b2 = await r2.json()
         assert r2.status == 200
-        assert "newtag" in b2["removed"]
-        assert "does-not-exist" in b2["not_present"]
+        assert set(b2["removed"]) == {"newtag"}
+        assert set(b2["not_present"]) == {"does-not-exist"}
 
+    # Verify remaining tags after deletion
+    async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
+        g2 = await rg2.json()
+        assert rg2.status == 200
+        tags_later = set(g2["tags"])
+        assert "newtag" not in tags_later
+        assert "beta" in tags_later  # still present
 
 
 @pytest.mark.asyncio
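For readers skimming the test, the request and response shapes it exercises can be inferred from the assertions above; the exact server-side schema and the contents of payload_add are not spelled out in this diff, so the sketch below is illustrative only.

# Inferred shapes, illustrative only. payload_add is a plausible value consistent
# with the assertions; the real one is defined earlier in the test file.
payload_add = {"tags": ["newtag", "beta", "unit-tests"]}
# POST /api/assets/{id}/tags ->
#   {"added": ["newtag", "beta"], "already_present": ["unit-tests"], "total_tags": [...]}

payload_del = {"tags": ["newtag", "does-not-exist"]}
# DELETE /api/assets/{id}/tags ->
#   {"removed": ["newtag"], "not_present": ["does-not-exist"]}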
@ -206,7 +206,7 @@ async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_b
         body = await r.json()
         assert r.status == 400
         assert body["error"]["code"] == "INVALID_BODY"
-        assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
+        assert body["error"]["message"].startswith("unknown models category")
 
 
 @pytest.mark.asyncio