global refactoring; add support for Assets without the computed hash
parent 934377ac1e, commit bb9ed04758
@@ -16,33 +16,44 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ASSETS: content identity (deduplicated by hash)
|
||||
# ASSETS: content identity
|
||||
op.create_table(
|
||||
"assets",
|
||||
sa.Column("hash", sa.String(length=256), primary_key=True),
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("hash", sa.String(length=256), nullable=True),
|
||||
sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("mime_type", sa.String(length=255), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
if op.get_bind().dialect.name == "postgresql":
|
||||
op.create_index(
|
||||
"uq_assets_hash_not_null",
|
||||
"assets",
|
||||
["hash"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("hash IS NOT NULL"),
|
||||
)
|
||||
else:
|
||||
op.create_index("uq_assets_hash", "assets", ["hash"], unique=True)
|
||||
op.create_index("ix_assets_mime_type", "assets", ["mime_type"])
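The branch above is what makes the hash optional: on PostgreSQL a partial unique index enforces uniqueness only for non-NULL hashes, while on other backends a plain unique index already treats NULLs as distinct. A minimal sketch of that behaviour, using a throwaway SQLite engine rather than this migration:

import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
assets_sketch = sa.Table(
    "assets", meta,
    sa.Column("id", sa.String(36), primary_key=True),
    sa.Column("hash", sa.String(256), nullable=True),
)
sa.Index("uq_assets_hash", assets_sketch.c.hash, unique=True)
meta.create_all(engine)
with engine.begin() as conn:
    # Two un-hashed assets coexist: NULLs are distinct for unique indexes.
    conn.execute(assets_sketch.insert(), [{"id": "a1", "hash": None}, {"id": "a2", "hash": None}])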
|
||||
|
||||
# ASSETS_INFO: user-visible references (mutable metadata)
|
||||
# ASSETS_INFO: user-visible references
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_hash", "assets_info", ["asset_hash"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
@@ -69,18 +80,19 @@ def upgrade() -> None:
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
# ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset
|
||||
# ASSET_CACHE_STATE: N:1 local cache rows per Asset
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
|
||||
op.create_table(
|
||||
@@ -99,7 +111,7 @@ def upgrade() -> None:
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
|
||||
# Tags vocabulary for models
|
||||
# Tags vocabulary
|
||||
tags_table = sa.table(
|
||||
"tags",
|
||||
sa.column("name", sa.String(length=512)),
|
||||
@@ -108,12 +120,10 @@ def upgrade() -> None:
|
||||
op.bulk_insert(
|
||||
tags_table,
|
||||
[
|
||||
# Root folder tags
|
||||
{"name": "models", "tag_type": "system"},
|
||||
{"name": "input", "tag_type": "system"},
|
||||
{"name": "output", "tag_type": "system"},
|
||||
|
||||
# Core tags
|
||||
{"name": "configs", "tag_type": "system"},
|
||||
{"name": "checkpoints", "tag_type": "system"},
|
||||
{"name": "loras", "tag_type": "system"},
|
||||
@@ -132,12 +142,11 @@ def upgrade() -> None:
|
||||
{"name": "photomaker", "tag_type": "system"},
|
||||
{"name": "classifiers", "tag_type": "system"},
|
||||
|
||||
# Extra basic tags
|
||||
{"name": "encoder", "tag_type": "system"},
|
||||
{"name": "decoder", "tag_type": "system"},
|
||||
|
||||
# Special tags
|
||||
{"name": "missing", "tag_type": "system"},
|
||||
{"name": "rescan", "tag_type": "system"},
|
||||
],
|
||||
)
|
||||
|
||||
@@ -149,8 +158,9 @@ def downgrade() -> None:
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_constraint("uq_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
@@ -160,14 +170,18 @@ def downgrade() -> None:
|
||||
op.drop_index("ix_tags_tag_type", table_name="tags")
|
||||
op.drop_table("tags")
|
||||
|
||||
op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info")
|
||||
op.drop_constraint("uq_assets_info_asset_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_hash", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
if op.get_bind().dialect.name == "postgresql":
|
||||
op.drop_index("uq_assets_hash_not_null", table_name="assets")
|
||||
else:
|
||||
op.drop_index("uq_assets_hash", table_name="assets")
|
||||
op.drop_index("ix_assets_mime_type", table_name="assets")
|
||||
op.drop_table("assets")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .assets_scanner import sync_seed_assets
|
||||
from .database.db import init_db_engine
|
||||
from .assets_scanner import start_background_assets_scan
|
||||
from .api.assets_routes import register_assets_system
|
||||
|
||||
|
||||
__all__ = ["init_db_engine", "start_background_assets_scan"]
|
||||
__all__ = ["init_db_engine", "sync_seed_assets", "register_assets_system"]
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import contextlib
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from typing import Literal, Optional, Sequence
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .database.models import AssetInfo
|
||||
from .api import schemas_in
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
@@ -139,14 +140,6 @@ def ensure_within_base(candidate: str, base: str) -> None:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
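A hedged usage sketch for the predicate above, reusing the module's sa and AssetInfo imports: an anonymous caller sees only shared rows, while a named owner also sees their own.

shared_only = sa.select(AssetInfo).where(visible_owner_clause(""))            # owner_id = ''
mine_and_shared = sa.select(AssetInfo).where(visible_owner_clause("alice"))   # owner_id IN ('', 'alice')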
|
||||
|
||||
|
||||
def compute_model_relative_filename(file_path: str) -> Optional[str]:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
@@ -172,3 +165,61 @@ def compute_model_relative_filename(file_path: str) -> Optional[str]:
|
||||
return None
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside) # normalize to POSIX style for portability
|
||||
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
|
||||
def prefixes_for_root(root: schemas_in.RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def ts_to_iso(ts: Optional[float]) -> Optional[str]:
|
||||
if ts is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
|
||||
except Exception:
|
||||
return None
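A quick check of the conversion above (an assumption-free example, since the function renders naive UTC):

assert ts_to_iso(0.0) == "1970-01-01T00:00:00"
assert ts_to_iso(None) is None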
|
||||
|
||||
|
||||
def new_scan_id(root: schemas_in.RootType) -> str:
|
||||
return f"scan-{root}-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
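The os.path.commonpath check above guards against the classic prefix bug; a small illustration with hypothetical paths:

import os

base = os.path.abspath("/data/models")
inside = os.path.abspath("/data/models/loras/style.safetensors")
lookalike = os.path.abspath("/data/models-extra/style.safetensors")

assert os.path.commonpath([inside, base]) == base       # accepted
assert os.path.commonpath([lookalike, base]) != base    # rejected, unlike a naive startswith test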
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import contextlib
|
||||
import os
|
||||
import uuid
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import web
|
||||
@@ -12,7 +12,6 @@ import folder_paths
|
||||
from .. import assets_manager, assets_scanner, user_manager
|
||||
from . import schemas_in, schemas_out
|
||||
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
UserManager: Optional[user_manager.UserManager] = None
|
||||
|
||||
@@ -272,6 +271,7 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=owner_id,
|
||||
expected_asset_hash=spec.hash,
|
||||
)
|
||||
status = 201 if created.created_new else 200
|
||||
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||
@@ -332,6 +332,29 @@ async def update_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
|
||||
async def set_asset_preview(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
try:
|
||||
body = schemas_in.SetPreviewBody.model_validate(await request.json())
|
||||
except ValidationError as ve:
|
||||
return _validation_error_response("INVALID_BODY", ve)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.set_asset_preview(
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=body.preview_id,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except (PermissionError, ValueError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
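A hedged client-side sketch of the new endpoint; the path and body shape come from the handler above, while the local address and default port are assumptions:

import asyncio
from typing import Optional

import aiohttp

async def set_preview(asset_info_id: str, preview_id: Optional[str]) -> dict:
    # PUT /api/assets/<id>/preview with {"preview_id": <asset uuid>} sets the preview; null clears it.
    async with aiohttp.ClientSession(base_url="http://127.0.0.1:8188") as session:
        async with session.put(f"/api/assets/{asset_info_id}/preview",
                               json={"preview_id": preview_id}) as resp:
            return await resp.json()

# asyncio.run(set_preview("<asset-info-uuid>", None))  # clears the preview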
|
||||
|
||||
|
||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from typing import Any, Optional, Literal
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_validator, model_validator, conint
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
conint,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
@@ -148,30 +156,12 @@ class TagsRemove(TagsAdd):
|
||||
pass
|
||||
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[Literal["models","input","output"]] = Field(default_factory=list)
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
@field_validator("roots", mode="before")
|
||||
@classmethod
|
||||
def _normalize_roots(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
items = [x.strip().lower() for x in v.split(",")]
|
||||
elif isinstance(v, list):
|
||||
items = []
|
||||
for x in v:
|
||||
if isinstance(x, str):
|
||||
items.extend([p.strip().lower() for p in x.split(",")])
|
||||
else:
|
||||
return []
|
||||
out = []
|
||||
seen = set()
|
||||
for r in items:
|
||||
if r in {"models","input","output"} and r not in seen:
|
||||
out.append(r)
|
||||
seen.add(r)
|
||||
return out
|
||||
|
||||
class ScheduleAssetScanBody(BaseModel):
|
||||
roots: list[RootType] = Field(..., min_length=1)
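With the stricter model above, the request body must name at least one known root. A couple of hedged examples, relying on the pydantic v2 semantics already imported in this module:

from pydantic import ValidationError

ScheduleAssetScanBody.model_validate({"roots": ["models", "output"]})  # accepted
for bad in ({"roots": []}, {"roots": ["nodes"]}):
    try:
        ScheduleAssetScanBody.model_validate(bad)
    except ValidationError:
        pass  # empty lists (min_length=1) and unknown roots are both rejected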
|
||||
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
@@ -281,3 +271,22 @@ class UploadAssetSpec(BaseModel):
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
return self
|
||||
|
||||
|
||||
class SetPreviewBody(BaseModel):
|
||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
||||
preview_id: Optional[str] = None
|
||||
|
||||
@field_validator("preview_id", mode="before")
|
||||
@classmethod
|
||||
def _norm_uuid(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
try:
|
||||
uuid.UUID(s)
|
||||
except Exception:
|
||||
raise ValueError("preview_id must be a UUID")
|
||||
return s
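Hedged examples of the normalisation above: blank values clear the preview, and anything that does not parse as a UUID is rejected.

assert SetPreviewBody(preview_id=None).preview_id is None
assert SetPreviewBody(preview_id="   ").preview_id is None           # blank -> treated as "clear"
SetPreviewBody(preview_id="3f2504e0-4f89-11d3-9a0c-0305e82c3301")    # valid UUID string is kept
# SetPreviewBody(preview_id="not-a-uuid")                            # would raise ValidationError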
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
@@ -31,7 +32,7 @@ class AssetsList(BaseModel):
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str
|
||||
asset_hash: Optional[str]
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: Optional[datetime] = None
|
||||
@@ -46,12 +47,12 @@ class AssetUpdated(BaseModel):
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str
|
||||
asset_hash: Optional[str]
|
||||
size: Optional[int] = None
|
||||
mime_type: Optional[str] = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_hash: Optional[str] = None
|
||||
preview_id: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
last_access_time: Optional[datetime] = None
|
||||
|
||||
@@ -95,7 +96,6 @@ class TagsRemove(BaseModel):
|
||||
class AssetScanError(BaseModel):
|
||||
path: str
|
||||
message: str
|
||||
phase: Literal["fast", "slow"]
|
||||
at: Optional[str] = Field(None, description="ISO timestamp")
|
||||
|
||||
|
||||
@@ -108,8 +108,6 @@ class AssetScanStatus(BaseModel):
|
||||
finished_at: Optional[str] = None
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
slow_queue_total: int = 0
|
||||
slow_queue_finished: int = 0
|
||||
file_errors: list[AssetScanError] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@@ -4,38 +4,39 @@ import mimetypes
|
||||
import os
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.internal import async_to_sync
|
||||
|
||||
from .database.db import create_session
|
||||
from .storage import hashing
|
||||
from .database.services import (
|
||||
check_fs_asset_exists_quick,
|
||||
ingest_fs_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
list_asset_infos_page,
|
||||
update_asset_info_full,
|
||||
get_asset_tags,
|
||||
list_tags_with_usage,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
fetch_asset_info_and_asset,
|
||||
touch_asset_info_by_id,
|
||||
delete_asset_info_by_id,
|
||||
asset_exists_by_hash,
|
||||
get_asset_by_hash,
|
||||
create_asset_info_for_existing_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_info_by_id,
|
||||
list_cache_states_by_asset_hash,
|
||||
asset_info_exists_for_hash,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from ._assets_helpers import (
|
||||
get_name_and_tags_from_asset_path,
|
||||
ensure_within_base,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.db import create_session
|
||||
from .database.models import Asset
|
||||
from .database.services import (
|
||||
add_tags_to_asset_info,
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
check_fs_asset_exists_quick,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_asset_tags,
|
||||
ingest_fs_asset,
|
||||
list_asset_infos_page,
|
||||
list_cache_states_by_asset_id,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_asset_info,
|
||||
set_asset_info_preview,
|
||||
touch_asset_info_by_id,
|
||||
touch_asset_infos_by_fs_path,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .storage import hashing
|
||||
|
||||
|
||||
async def asset_exists(*, asset_hash: str) -> bool:
|
||||
@@ -44,29 +45,21 @@ async def asset_exists(*, asset_hash: str) -> bool:
|
||||
|
||||
|
||||
def populate_db_with_asset(file_path: str, tags: Optional[list[str]] = None) -> None:
|
||||
if not args.enable_model_processing:
|
||||
if tags is None:
|
||||
tags = []
|
||||
try:
|
||||
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
|
||||
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||
add_local_asset,
|
||||
tags=list(dict.fromkeys([*path_tags, *tags])),
|
||||
file_name=asset_name,
|
||||
file_path=file_path,
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.warning("Skipping non-asset path %s: %s", file_path, e)
|
||||
if tags is None:
|
||||
tags = []
|
||||
try:
|
||||
asset_name, path_tags = get_name_and_tags_from_asset_path(file_path)
|
||||
async_to_sync.AsyncToSyncConverter.run_async_in_thread(
|
||||
add_local_asset,
|
||||
tags=list(dict.fromkeys([*path_tags, *tags])),
|
||||
file_name=asset_name,
|
||||
file_path=file_path,
|
||||
)
|
||||
except ValueError as e:
|
||||
logging.warning("Skipping non-asset path %s: %s", file_path, e)
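The helper stays synchronous so existing model-management code can call it inline. A hedged usage sketch with a hypothetical path and an extra tag:

# Path-derived tags (for example "models", "loras") are merged with the ones passed in.
populate_db_with_asset("/data/ComfyUI/models/loras/style.safetensors", tags=["favourite"])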
|
||||
|
||||
|
||||
async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> None:
|
||||
"""Adds a local asset to the DB. If already present and unchanged, does nothing.
|
||||
|
||||
Notes:
|
||||
- Uses absolute path as the canonical locator for the cache backend.
|
||||
- Computes BLAKE3 only when the fast existence check indicates it's needed.
|
||||
- This function ensures the identity row and seeds mtime in asset_cache_state.
|
||||
"""
|
||||
abs_path = os.path.abspath(file_path)
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(abs_path)
|
||||
if not size_bytes:
|
||||
@@ -132,7 +125,7 @@ async def list_assets(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset_hash,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
@@ -156,16 +149,17 @@ async def get_asset(*, asset_info_id: str, owner_id: str = "") -> schemas_out.As
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset_hash,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
preview_hash=info.preview_hash,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
@@ -176,20 +170,13 @@ async def resolve_asset_content_for_download(
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
|
||||
Also touches last_access_time (only_if_newer).
|
||||
Raises:
|
||||
ValueError if AssetInfo cannot be found
|
||||
FileNotFoundError if file for Asset cannot be found
|
||||
"""
|
||||
async with await create_session() as session:
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = await list_cache_states_by_asset_hash(session, asset_hash=info.asset_hash)
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = ""
|
||||
for s in states:
|
||||
if s and s.file_path and os.path.isfile(s.file_path):
|
||||
@@ -214,16 +201,6 @@ async def upload_asset_from_temp_path(
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: Optional[str] = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Finalize an uploaded temp file:
|
||||
- compute blake3 hash
|
||||
- if expected_asset_hash provided, verify equality (400 on mismatch at caller)
|
||||
- if an Asset with the same hash exists: discard temp, create AssetInfo only (no write)
|
||||
- else resolve destination from tags and atomically move into place
|
||||
- ingest into DB (assets, locator state, asset_info + tags)
|
||||
Returns a populated AssetCreated payload.
|
||||
"""
|
||||
|
||||
try:
|
||||
digest = await hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
@@ -233,7 +210,6 @@ async def upload_asset_from_temp_path(
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
# Fast path: content already known --> no writes, just create a reference
|
||||
async with await create_session() as session:
|
||||
existing = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
@@ -257,43 +233,37 @@ async def upload_asset_from_temp_path(
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset_hash,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_hash=info.preview_hash,
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
# Resolve destination (only for truly new content)
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
# Decide filename
|
||||
desired_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, desired_name))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
# Content type based on final name
|
||||
content_type = mimetypes.guess_type(desired_name, strict=False)[0] or "application/octet-stream"
|
||||
|
||||
# Atomic move into place
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
# Stat final file
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
# Ingest + build response
|
||||
async with await create_session() as session:
|
||||
result = await ingest_fs_asset(
|
||||
session,
|
||||
@@ -304,7 +274,7 @@ async def upload_asset_from_temp_path(
|
||||
mime_type=content_type,
|
||||
info_name=os.path.basename(dest_abs),
|
||||
owner_id=owner_id,
|
||||
preview_hash=None,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
@@ -324,12 +294,12 @@ async def upload_asset_from_temp_path(
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset_hash,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_hash=info.preview_hash,
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
@@ -367,38 +337,74 @@ async def update_asset(
|
||||
return schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset_hash,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
"""Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default),
|
||||
delete the Asset row as well and remove all cached files recorded for that asset_hash.
|
||||
"""
|
||||
async def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: Optional[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_hash = info_row.asset_hash if info_row else None
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
await set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
await session.commit()
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
await session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_hash:
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash)
|
||||
still_exists = await asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash)
|
||||
states = await list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = await get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
asset_row = await session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
await session.delete(asset_row)
|
||||
|
||||
@@ -439,12 +445,12 @@ async def create_asset_from_hash(
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset_hash,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_hash=info.preview_hash,
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
|
||||
@@ -1,52 +1,55 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Literal, Optional, Sequence
|
||||
from typing import Literal, Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import folder_paths
|
||||
|
||||
from . import assets_manager
|
||||
from .api import schemas_out
|
||||
from ._assets_helpers import get_comfy_models_folders
|
||||
from ._assets_helpers import (
|
||||
collect_models_files,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
list_tree,
|
||||
new_scan_id,
|
||||
prefixes_for_root,
|
||||
ts_to_iso,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from .database.db import create_session
|
||||
from .database.helpers import (
|
||||
add_missing_tag_for_asset_id,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
from .database.models import Asset, AssetCacheState, AssetInfo
|
||||
from .database.services import (
|
||||
check_fs_asset_exists_quick,
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
ensure_seed_for_path,
|
||||
list_cache_states_by_asset_id,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
add_missing_tag_for_asset_hash,
|
||||
remove_missing_tag_for_asset_hash,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
)
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
SLOW_HASH_CONCURRENCY = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanProgress:
|
||||
scan_id: str
|
||||
root: RootType
|
||||
root: schemas_in.RootType
|
||||
status: Literal["scheduled", "running", "completed", "failed", "cancelled"] = "scheduled"
|
||||
scheduled_at: float = field(default_factory=lambda: time.time())
|
||||
started_at: Optional[float] = None
|
||||
finished_at: Optional[float] = None
|
||||
|
||||
discovered: int = 0
|
||||
processed: int = 0
|
||||
slow_queue_total: int = 0
|
||||
slow_queue_finished: int = 0
|
||||
file_errors: list[dict] = field(default_factory=list) # {"path","message","phase","at"}
|
||||
|
||||
# Internal diagnostics for logs
|
||||
_fast_total_seen: int = 0
|
||||
_fast_clean: int = 0
|
||||
file_errors: list[dict] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -56,18 +59,14 @@ class SlowQueueState:
|
||||
closed: bool = False
|
||||
|
||||
|
||||
RUNNING_TASKS: dict[RootType, asyncio.Task] = {}
|
||||
PROGRESS_BY_ROOT: dict[RootType, ScanProgress] = {}
|
||||
SLOW_STATE_BY_ROOT: dict[RootType, SlowQueueState] = {}
|
||||
|
||||
|
||||
async def start_background_assets_scan():
|
||||
await fast_reconcile_and_kickoff(progress_cb=_console_cb)
|
||||
RUNNING_TASKS: dict[schemas_in.RootType, asyncio.Task] = {}
|
||||
PROGRESS_BY_ROOT: dict[schemas_in.RootType, ScanProgress] = {}
|
||||
SLOW_STATE_BY_ROOT: dict[schemas_in.RootType, SlowQueueState] = {}
|
||||
|
||||
|
||||
def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
scans = []
|
||||
for root in ALLOWED_ROOTS:
|
||||
for root in schemas_in.ALLOWED_ROOTS:
|
||||
prog = PROGRESS_BY_ROOT.get(root)
|
||||
if not prog:
|
||||
continue
|
||||
@@ -75,83 +74,65 @@ def current_statuses() -> schemas_out.AssetScanStatusResponse:
|
||||
return schemas_out.AssetScanStatusResponse(scans=scans)
|
||||
|
||||
|
||||
async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusResponse:
|
||||
"""Schedule scans for the provided roots; returns progress snapshots.
|
||||
|
||||
Rules:
|
||||
- Only roots in {models, input, output} are accepted.
|
||||
- If a root is already scanning, we do NOT enqueue another one. Status returned as-is.
|
||||
- Otherwise a new task is created and started immediately.
|
||||
- Files with zero size are skipped.
|
||||
"""
|
||||
normalized: list[RootType] = []
|
||||
seen = set()
|
||||
for r in roots or []:
|
||||
rr = r.strip().lower()
|
||||
if rr in ALLOWED_ROOTS and rr not in seen:
|
||||
normalized.append(rr) # type: ignore
|
||||
seen.add(rr)
|
||||
if not normalized:
|
||||
normalized = list(ALLOWED_ROOTS) # schedule all by default
|
||||
|
||||
async def schedule_scans(roots: list[schemas_in.RootType]) -> schemas_out.AssetScanStatusResponse:
|
||||
results: list[ScanProgress] = []
|
||||
for root in normalized:
|
||||
for root in roots:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
results.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
|
||||
prog = ScanProgress(scan_id=new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
SLOW_STATE_BY_ROOT[root] = SlowQueueState(queue=asyncio.Queue())
|
||||
state = SlowQueueState(queue=asyncio.Queue())
|
||||
SLOW_STATE_BY_ROOT[root] = state
|
||||
RUNNING_TASKS[root] = asyncio.create_task(
|
||||
_pipeline_for_root(root, prog, progress_cb=None),
|
||||
_run_hash_verify_pipeline(root, prog, state),
|
||||
name=f"asset-scan:{root}",
|
||||
)
|
||||
results.append(prog)
|
||||
return _status_response_for(results)
|
||||
|
||||
|
||||
async def fast_reconcile_and_kickoff(
|
||||
roots: Optional[Sequence[str]] = None,
|
||||
*,
|
||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None,
|
||||
) -> schemas_out.AssetScanStatusResponse:
|
||||
"""
|
||||
Startup helper: do the fast pass now (so we know queue size),
|
||||
start slow hashing in the background, return immediately.
|
||||
"""
|
||||
normalized = [*ALLOWED_ROOTS] if not roots else [r for r in roots if r in ALLOWED_ROOTS]
|
||||
snaps: list[ScanProgress] = []
|
||||
|
||||
for root in normalized:
|
||||
if root in RUNNING_TASKS and not RUNNING_TASKS[root].done():
|
||||
snaps.append(PROGRESS_BY_ROOT[root])
|
||||
continue
|
||||
|
||||
prog = ScanProgress(scan_id=_new_scan_id(root), root=root, status="scheduled")
|
||||
PROGRESS_BY_ROOT[root] = prog
|
||||
state = SlowQueueState(queue=asyncio.Queue())
|
||||
SLOW_STATE_BY_ROOT[root] = state
|
||||
|
||||
prog.status = "running"
|
||||
prog.started_at = time.time()
|
||||
async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
|
||||
for r in roots:
|
||||
try:
|
||||
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
|
||||
except Exception as e:
|
||||
_append_error(prog, phase="fast", path="", message=str(e))
|
||||
prog.status = "failed"
|
||||
prog.finished_at = time.time()
|
||||
LOGGER.exception("Fast reconcile failed for %s", root)
|
||||
snaps.append(prog)
|
||||
await _fast_db_consistency_pass(r)
|
||||
except Exception as ex:
|
||||
LOGGER.exception("fast DB reconciliation failed for %s: %s", r, ex)
|
||||
|
||||
paths: list[str] = []
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
|
||||
for p in paths:
|
||||
try:
|
||||
st = os.stat(p, follow_symlinks=True)
|
||||
if not int(st.st_size or 0):
|
||||
continue
|
||||
size_bytes = int(st.st_size)
|
||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
name, tags = get_name_and_tags_from_asset_path(p)
|
||||
await _seed_one_async(p, size_bytes, mtime_ns, name, tags)
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
|
||||
RUNNING_TASKS[root] = asyncio.create_task(
|
||||
_await_workers_then_finish(root, prog, state, progress_cb=progress_cb),
|
||||
name=f"asset-hash:{root}",
|
||||
|
||||
async def _seed_one_async(p: str, size_bytes: int, mtime_ns: int, name: str, tags: list[str]) -> None:
|
||||
async with await create_session() as sess:
|
||||
await ensure_seed_for_path(
|
||||
sess,
|
||||
abs_path=p,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
info_name=name,
|
||||
tags=tags,
|
||||
owner_id="",
|
||||
)
|
||||
snaps.append(prog)
|
||||
return _status_response_for(snaps)
|
||||
await sess.commit()
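A hedged startup sketch: seeding registers on-disk files with hash = NULL immediately, and the hash/verify pipeline fills the hashes in later. The standalone asyncio.run call is an assumption; in ComfyUI this is expected to be wired into server startup via the package's exported entry points.

import asyncio

async def seed_all() -> None:
    # sync_seed_assets is the coroutine defined above and exported from this package's __init__.
    await sync_seed_assets(["models", "input", "output"])

# asyncio.run(seed_all())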
|
||||
|
||||
|
||||
def _status_response_for(progresses: list[ScanProgress]) -> schemas_out.AssetScanStatusResponse:
|
||||
@@ -163,18 +144,15 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A
|
||||
scan_id=progress.scan_id,
|
||||
root=progress.root,
|
||||
status=progress.status,
|
||||
scheduled_at=_ts_to_iso(progress.scheduled_at),
|
||||
started_at=_ts_to_iso(progress.started_at),
|
||||
finished_at=_ts_to_iso(progress.finished_at),
|
||||
scheduled_at=ts_to_iso(progress.scheduled_at),
|
||||
started_at=ts_to_iso(progress.started_at),
|
||||
finished_at=ts_to_iso(progress.finished_at),
|
||||
discovered=progress.discovered,
|
||||
processed=progress.processed,
|
||||
slow_queue_total=progress.slow_queue_total,
|
||||
slow_queue_finished=progress.slow_queue_finished,
|
||||
file_errors=[
|
||||
schemas_out.AssetScanError(
|
||||
path=e.get("path", ""),
|
||||
message=e.get("message", ""),
|
||||
phase=e.get("phase", "slow"),
|
||||
at=e.get("at"),
|
||||
)
|
||||
for e in (progress.file_errors or [])
|
||||
@@ -182,27 +160,100 @@ def _scan_progress_to_scan_status_model(progress: ScanProgress) -> schemas_out.A
|
||||
)
|
||||
|
||||
|
||||
async def _pipeline_for_root(
|
||||
root: RootType,
|
||||
prog: ScanProgress,
|
||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
||||
) -> None:
|
||||
state = SLOW_STATE_BY_ROOT.get(root) or SlowQueueState(queue=asyncio.Queue())
|
||||
SLOW_STATE_BY_ROOT[root] = state
|
||||
async def _refresh_verify_flags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
|
||||
"""Fast pass to mark verify candidates by comparing stored mtime_ns with on-disk mtime."""
|
||||
prefixes = prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||
|
||||
async with await create_session() as sess:
|
||||
rows = (
|
||||
await sess.execute(
|
||||
sa.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
Asset.hash,
|
||||
AssetCacheState.file_path,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sa.or_(*conds))
|
||||
)
|
||||
).all()
|
||||
|
||||
to_set = []
|
||||
to_clear = []
|
||||
for sid, mtime_db, needs_verify, a_hash, fp in rows:
|
||||
try:
|
||||
st = os.stat(fp, follow_symlinks=True)
|
||||
except OSError:
|
||||
# Missing files are handled by missing-tag reconciliation later.
|
||||
continue
|
||||
|
||||
actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
if a_hash is not None:
|
||||
if mtime_db is None or int(mtime_db) != int(actual_mtime_ns):
|
||||
if not needs_verify:
|
||||
to_set.append(sid)
|
||||
else:
|
||||
if needs_verify:
|
||||
to_clear.append(sid)
|
||||
|
||||
if to_set:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear:
|
||||
await sess.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
await sess.commit()
|
||||
|
||||
|
||||
async def _run_hash_verify_pipeline(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
prog.status = "running"
|
||||
prog.started_at = time.time()
|
||||
|
||||
try:
|
||||
await _reconcile_missing_tags_for_root(root, prog)
|
||||
await _fast_reconcile_into_queue(root, prog, state, progress_cb=progress_cb)
|
||||
_start_slow_workers(root, prog, state, progress_cb=progress_cb)
|
||||
await _await_workers_then_finish(root, prog, state, progress_cb=progress_cb)
|
||||
prefixes = prefixes_for_root(root)
|
||||
|
||||
await _refresh_verify_flags_for_root(root, prog)
|
||||
|
||||
# collect candidates from DB
|
||||
async with await create_session() as sess:
|
||||
verify_ids = await list_verify_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
unhashed_ids = await list_unhashed_candidates_under_prefixes(sess, prefixes=prefixes)
|
||||
# dedupe: prioritize verification first
|
||||
seen = set()
|
||||
ordered: list[int] = []
|
||||
for lst in (verify_ids, unhashed_ids):
|
||||
for sid in lst:
|
||||
if sid not in seen:
|
||||
seen.add(sid); ordered.append(sid)
|
||||
|
||||
prog.discovered = len(ordered)
|
||||
|
||||
# queue up work
|
||||
for sid in ordered:
|
||||
await state.queue.put(sid)
|
||||
state.closed = True
|
||||
_start_state_workers(root, prog, state)
|
||||
await _await_state_workers_then_finish(root, prog, state)
|
||||
except asyncio.CancelledError:
|
||||
prog.status = "cancelled"
|
||||
raise
|
||||
except Exception as exc:
|
||||
_append_error(prog, phase="slow", path="", message=str(exc))
|
||||
_append_error(prog, path="", message=str(exc))
|
||||
prog.status = "failed"
|
||||
prog.finished_at = time.time()
|
||||
LOGGER.exception("Asset scan failed for %s", root)
|
||||
@@ -210,110 +261,13 @@ async def _pipeline_for_root(
|
||||
RUNNING_TASKS.pop(root, None)
|
||||
|
||||
|
||||
async def _fast_reconcile_into_queue(
|
||||
root: RootType,
|
||||
prog: ScanProgress,
|
||||
state: SlowQueueState,
|
||||
*,
|
||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
||||
) -> None:
|
||||
async def _reconcile_missing_tags_for_root(root: schemas_in.RootType, prog: ScanProgress) -> None:
|
||||
"""
|
||||
Enumerate files, set 'discovered' to total files seen, increment 'processed' for fast-matched files,
|
||||
and queue the rest for slow hashing.
|
||||
"""
|
||||
if root == "models":
|
||||
files = _collect_models_files()
|
||||
preset_discovered = _count_nonzero_in_list(files)
|
||||
files_iter = asyncio.Queue()
|
||||
for p in files:
|
||||
await files_iter.put(p)
|
||||
await files_iter.put(None) # sentinel for our local draining loop
|
||||
elif root == "input":
|
||||
base = folder_paths.get_input_directory()
|
||||
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
|
||||
files_iter = await _queue_tree_files(base)
|
||||
elif root == "output":
|
||||
base = folder_paths.get_output_directory()
|
||||
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
|
||||
files_iter = await _queue_tree_files(base)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported root: {root}")
|
||||
Detect missing files quickly and toggle 'missing' tag per asset_id.
|
||||
|
||||
prog.discovered = int(preset_discovered or 0)
|
||||
|
||||
queued = 0
|
||||
checked = 0
|
||||
clean = 0
|
||||
|
||||
async with await create_session() as sess:
|
||||
while True:
|
||||
item = await files_iter.get()
|
||||
files_iter.task_done()
|
||||
if item is None:
|
||||
break
|
||||
|
||||
abs_path = item
|
||||
checked += 1
|
||||
|
||||
# Stat; skip empty/unreadable
|
||||
try:
|
||||
st = os.stat(abs_path, follow_symlinks=True)
|
||||
if not st.st_size:
|
||||
continue
|
||||
size_bytes = int(st.st_size)
|
||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
except OSError as e:
|
||||
_append_error(prog, phase="fast", path=abs_path, message=str(e))
|
||||
continue
|
||||
|
||||
try:
|
||||
known = await check_fs_asset_exists_quick(
|
||||
sess,
|
||||
file_path=abs_path,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
)
|
||||
except Exception as e:
|
||||
_append_error(prog, phase="fast", path=abs_path, message=str(e))
|
||||
known = False
|
||||
|
||||
if known:
|
||||
clean += 1
|
||||
prog.processed += 1
|
||||
else:
|
||||
await state.queue.put(abs_path)
|
||||
queued += 1
|
||||
prog.slow_queue_total += 1
|
||||
|
||||
if progress_cb:
|
||||
progress_cb(root, "fast", prog.processed, False, {
|
||||
"checked": checked,
|
||||
"clean": clean,
|
||||
"queued": queued,
|
||||
"discovered": prog.discovered,
|
||||
})
|
||||
|
||||
prog._fast_total_seen = checked
|
||||
prog._fast_clean = clean
|
||||
|
||||
if progress_cb:
|
||||
progress_cb(root, "fast", prog.processed, True, {
|
||||
"checked": checked,
|
||||
"clean": clean,
|
||||
"queued": queued,
|
||||
"discovered": prog.discovered,
|
||||
})
|
||||
state.closed = True
|
||||
|
||||
|
||||
async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -> None:
|
||||
"""
|
||||
Logic for detecting missing Assets files:
|
||||
- Clear 'missing' only if at least one cached path passes fast check:
|
||||
exists AND mtime_ns matches AND size matches.
|
||||
- Otherwise set 'missing'.
|
||||
Files that exist but fail fast check will be slow-hashed by the normal pipeline,
|
||||
and ingest_fs_asset will clear 'missing' if they truly match.
|
||||
Rules:
|
||||
- Only hashed assets (assets.hash != NULL) participate in missing tagging.
|
||||
- We consider ALL cache states of the asset (across roots) before tagging.
|
||||
"""
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
@@ -326,232 +280,217 @@ async def _reconcile_missing_tags_for_root(root: RootType, prog: ScanProgress) -
|
||||
|
||||
try:
|
||||
async with await create_session() as sess:
|
||||
# state + hash + size for the current root
|
||||
rows = await list_cache_states_with_asset_under_prefixes(sess, prefixes=bases)
|
||||
|
||||
by_hash: dict[str, dict[str, bool]] = {} # {hash: {"any_fast_ok": bool}}
|
||||
for state, size_db in rows:
|
||||
h = state.asset_hash
|
||||
acc = by_hash.get(h)
|
||||
# Track fast_ok within the scanned root and whether the asset is hashed
|
||||
by_asset: dict[str, dict[str, bool]] = {}
|
||||
for state, a_hash, size_db in rows:
|
||||
aid = state.asset_id
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"any_fast_ok": False}
|
||||
by_hash[h] = acc
|
||||
acc = {"any_fast_ok_here": False, "hashed": (a_hash is not None), "size_db": int(size_db or 0)}
|
||||
by_asset[aid] = acc
|
||||
try:
|
||||
st = os.stat(state.file_path, follow_symlinks=True)
|
||||
actual_mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
fast_ok = False
|
||||
if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
|
||||
if int(size_db) > 0 and int(st.st_size) == int(size_db):
|
||||
fast_ok = True
|
||||
if acc["hashed"]:
|
||||
if state.mtime_ns is not None and int(state.mtime_ns) == int(actual_mtime_ns):
|
||||
if int(acc["size_db"]) > 0 and int(st.st_size) == int(acc["size_db"]):
|
||||
fast_ok = True
|
||||
if fast_ok:
|
||||
acc["any_fast_ok"] = True
|
||||
acc["any_fast_ok_here"] = True
|
||||
except FileNotFoundError:
|
||||
pass # not fast_ok
|
||||
pass
|
||||
except OSError as e:
|
||||
_append_error(prog, phase="fast", path=state.file_path, message=str(e))
|
||||
_append_error(prog, path=state.file_path, message=str(e))
|
||||
|
||||
for h, acc in by_hash.items():
|
||||
# Decide per asset, considering ALL its states (not just this root)
|
||||
for aid, acc in by_asset.items():
|
||||
try:
|
||||
if acc["any_fast_ok"]:
|
||||
await remove_missing_tag_for_asset_hash(sess, asset_hash=h)
|
||||
if not acc["hashed"]:
|
||||
# Never tag seed assets as missing
|
||||
continue
|
||||
|
||||
any_fast_ok_global = acc["any_fast_ok_here"]
|
||||
if not any_fast_ok_global:
|
||||
# Check other states outside this root
|
||||
others = await list_cache_states_by_asset_id(sess, asset_id=aid)
|
||||
for st in others:
|
||||
try:
|
||||
s = os.stat(st.file_path, follow_symlinks=True)
|
||||
actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
|
||||
if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
|
||||
if acc["size_db"] > 0 and int(s.st_size) == acc["size_db"]:
|
||||
any_fast_ok_global = True
|
||||
break
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
if any_fast_ok_global:
|
||||
await remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
else:
|
||||
await add_missing_tag_for_asset_hash(sess, asset_hash=h, origin="automatic")
|
||||
await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
except Exception as ex:
|
||||
_append_error(prog, phase="fast", path="", message=f"reconcile {h[:18]}: {ex}")
|
||||
_append_error(prog, path="", message=f"reconcile {aid[:8]}: {ex}")
|
||||
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
_append_error(prog, phase="fast", path="", message=f"reconcile failed: {e}")
|
||||
_append_error(prog, path="", message=f"reconcile failed: {e}")
|
||||
|
||||
|
||||
def _start_slow_workers(
|
||||
root: RootType,
|
||||
prog: ScanProgress,
|
||||
state: SlowQueueState,
|
||||
*,
|
||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
||||
) -> None:
|
||||
def _start_state_workers(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
if state.workers:
|
||||
return
|
||||
|
||||
async def _worker(_worker_id: int):
|
||||
async def _worker(_wid: int):
|
||||
while True:
|
||||
item = await state.queue.get()
|
||||
sid = await state.queue.get()
|
||||
try:
|
||||
if item is None:
|
||||
if sid is None:
|
||||
return
|
||||
try:
|
||||
await asyncio.to_thread(assets_manager.populate_db_with_asset, item)
|
||||
except Exception as e:
|
||||
_append_error(prog, phase="slow", path=item, message=str(e))
|
||||
async with await create_session() as sess:
|
||||
# Optional: fetch path for better error messages
|
||||
st = await sess.get(AssetCacheState, sid)
|
||||
try:
|
||||
await compute_hash_and_dedup_for_cache_state(sess, state_id=sid)
|
||||
await sess.commit()
|
||||
except Exception as e:
|
||||
path = st.file_path if st else f"state:{sid}"
|
||||
_append_error(prog, path=path, message=str(e))
|
||||
raise
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Slow queue finished for this item; also counts toward overall processed
|
||||
prog.slow_queue_finished += 1
|
||||
prog.processed += 1
|
||||
if progress_cb:
|
||||
progress_cb(root, "slow", prog.processed, False, {
|
||||
"slow_queue_finished": prog.slow_queue_finished,
|
||||
"slow_queue_total": prog.slow_queue_total,
|
||||
})
|
||||
finally:
|
||||
state.queue.task_done()
|
||||
|
||||
state.workers = [asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}") for i in range(SLOW_HASH_CONCURRENCY)]
|
||||
state.workers = [
|
||||
asyncio.create_task(_worker(i), name=f"asset-hash:{root}:{i}")
|
||||
for i in range(SLOW_HASH_CONCURRENCY)
|
||||
]
|
||||
|
||||
async def _close_when_empty():
|
||||
# When the fast phase closed the queue, push sentinels to end workers
|
||||
async def _close_when_ready():
|
||||
while not state.closed:
|
||||
await asyncio.sleep(0.05)
|
||||
for _ in range(SLOW_HASH_CONCURRENCY):
|
||||
await state.queue.put(None)
|
||||
|
||||
asyncio.create_task(_close_when_empty())
|
||||
asyncio.create_task(_close_when_ready())
|
||||
|
||||
|
||||
async def _await_workers_then_finish(
|
||||
root: RootType,
|
||||
prog: ScanProgress,
|
||||
state: SlowQueueState,
|
||||
*,
|
||||
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]],
|
||||
) -> None:
|
||||
async def _await_state_workers_then_finish(root: schemas_in.RootType, prog: ScanProgress, state: SlowQueueState) -> None:
|
||||
if state.workers:
|
||||
await asyncio.gather(*state.workers, return_exceptions=True)
|
||||
await _reconcile_missing_tags_for_root(root, prog)
|
||||
prog.finished_at = time.time()
|
||||
prog.status = "completed"
|
||||
if progress_cb:
|
||||
progress_cb(root, "slow", prog.processed, True, {
|
||||
"slow_queue_finished": prog.slow_queue_finished,
|
||||
"slow_queue_total": prog.slow_queue_total,
|
||||
})
|
||||
|
||||
|
||||
def _collect_models_files() -> list[str]:
    """Collect absolute file paths from configured model buckets under models_dir."""
    out: list[str] = []
    for folder_name, bases in get_comfy_models_folders():
        rel_files = folder_paths.get_filename_list(folder_name) or []
        for rel_path in rel_files:
            abs_path = folder_paths.get_full_path(folder_name, rel_path)
            if not abs_path:
                continue
            abs_path = os.path.abspath(abs_path)
            # ensure within allowed bases
            allowed = False
            for b in bases:
                base_abs = os.path.abspath(b)
                with contextlib.suppress(Exception):
                    if os.path.commonpath([abs_path, base_abs]) == base_abs:
                        allowed = True
                        break
            if allowed:
                out.append(abs_path)
    return out


def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int:
    if not os.path.isdir(base_abs):
        return 0
    total = 0
    for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
        if not only_nonzero:
            total += len(filenames)
        else:
            for name in filenames:
                with contextlib.suppress(OSError):
                    st = os.stat(os.path.join(dirpath, name), follow_symlinks=True)
                    if st.st_size:
                        total += 1
    return total


def _count_nonzero_in_list(paths: list[str]) -> int:
    cnt = 0
    for p in paths:
        with contextlib.suppress(OSError):
            st = os.stat(p, follow_symlinks=True)
            if st.st_size:
                cnt += 1
    return cnt


async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
    """
    Walk base_dir in a worker thread and return a queue prefilled with all paths,
    terminated by a single None sentinel for the draining loop in fast reconcile.
    """
    q: asyncio.Queue = asyncio.Queue()
    base_abs = os.path.abspath(base_dir)
    if not os.path.isdir(base_abs):
        await q.put(None)
        return q

    def _walk_list():
        paths: list[str] = []
        for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
            for name in filenames:
                paths.append(os.path.abspath(os.path.join(dirpath, name)))
        return paths

    for p in await asyncio.to_thread(_walk_list):
        await q.put(p)
    await q.put(None)
    return q


def _append_error(prog: ScanProgress, *, phase: Literal["fast", "slow"], path: str, message: str) -> None:
def _append_error(prog: ScanProgress, *, path: str, message: str) -> None:
    prog.file_errors.append({
        "path": path,
        "message": message,
        "phase": phase,
        "at": _ts_to_iso(time.time()),
        "at": ts_to_iso(time.time()),
    })


def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
    if ts is None:
        return None
    # interpret ts as seconds since epoch UTC and return naive UTC (consistent with other models)
    try:
        return datetime.fromtimestamp(float(ts), tz=timezone.utc).replace(tzinfo=None).isoformat()
    except Exception:
        return None


async def _fast_db_consistency_pass(root: schemas_in.RootType) -> None:
    """
    Quick pass over asset_cache_state for `root`:
      - If file missing and Asset.hash is NULL and the Asset has no other states, delete the Asset and its infos.
      - If file missing and Asset.hash is NOT NULL:
          * If at least one state for this Asset is fast-ok, delete the missing state.
          * If none are fast-ok, add 'missing' tag to all AssetInfos for this Asset.
      - If at least one state becomes fast-ok for a hashed Asset, remove the 'missing' tag.
    """
    prefixes = prefixes_for_root(root)
    if not prefixes:
        return

    conds = []
    for p in prefixes:
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base += os.sep
        conds.append(AssetCacheState.file_path.like(base + "%"))

    async with await create_session() as sess:
        if not conds:
            return

        rows = (
            await sess.execute(
                sa.select(AssetCacheState, Asset.hash, Asset.size_bytes)
                .join(Asset, Asset.id == AssetCacheState.asset_id)
                .where(sa.or_(*conds))
                .order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
            )
        ).all()

        # Group by asset_id with status per state
        by_asset: dict[str, dict] = {}
        for st, a_hash, a_size in rows:
            aid = st.asset_id
            acc = by_asset.get(aid)
            if acc is None:
                acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
                by_asset[aid] = acc

            exists = False
            fast_ok = False
            try:
                s = os.stat(st.file_path, follow_symlinks=True)
                exists = True
                actual_mtime_ns = getattr(s, "st_mtime_ns", int(s.st_mtime * 1_000_000_000))
                if st.mtime_ns is not None and int(st.mtime_ns) == int(actual_mtime_ns):
                    if acc["size_db"] == 0 or int(s.st_size) == acc["size_db"]:
                        fast_ok = True
            except FileNotFoundError:
                exists = False
            except OSError as ex:
                exists = False
                LOGGER.debug("fast pass stat error for %s: %s", st.file_path, ex)

            acc["states"].append({"obj": st, "exists": exists, "fast_ok": fast_ok})

        # Apply actions
        for aid, acc in by_asset.items():
            a_hash = acc["hash"]
            states = acc["states"]
            any_fast_ok = any(s["fast_ok"] for s in states)
            all_missing = all(not s["exists"] for s in states)
            missing_states = [s["obj"] for s in states if not s["exists"]]

            if a_hash is None:
                # Seed asset: if all states gone (and in practice there is only one), remove the whole Asset
                if states and all_missing:
                    await sess.execute(sa.delete(AssetInfo).where(AssetInfo.asset_id == aid))
                    asset = await sess.get(Asset, aid)
                    if asset:
                        await sess.delete(asset)
                # else leave it for the slow scan to verify/rehash
            else:
                if any_fast_ok:
                    # Remove 'missing' and delete just the stale state rows
                    for st in missing_states:
                        try:
                            await sess.delete(await sess.get(AssetCacheState, st.id))
                        except Exception:
                            pass
                    try:
                        await remove_missing_tag_for_asset_id(sess, asset_id=aid)
                    except Exception:
                        pass
                else:
                    # No fast-ok path: mark as missing
                    try:
                        await add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
                    except Exception:
                        pass

        await sess.flush()
        await sess.commit()


def _new_scan_id(root: RootType) -> str:
    return f"scan-{root}-{uuid.uuid4().hex[:8]}"


def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e: dict):
    if phase == "fast":
        if finished:
            logging.info(
                "[assets][%s] fast done: processed=%s/%s queued=%s",
                root,
                total_processed,
                e["discovered"],
                e["queued"],
            )
        elif e.get("checked", 0) % 1000 == 0:  # do not spam with fast progress
            logging.info(
                "[assets][%s] fast progress: processed=%s/%s",
                root,
                total_processed,
                e["discovered"],
            )
    elif phase == "slow":
        if finished:
            if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0):
                logging.info(
                    "[assets][%s] slow done: %s/%s",
                    root,
                    e.get("slow_queue_finished", 0),
                    e.get("slow_queue_total", 0),
                )
        elif e.get('slow_queue_finished', 0) % 3 == 0:
            logging.info(
                "[assets][%s] slow progress: %s/%s",
                root,
                e.get("slow_queue_finished", 0),
                e.get("slow_queue_total", 0),
            )

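# --- Illustrative sketch, not part of this diff: the per-asset decision rules of
# --- _fast_db_consistency_pass expressed as a small pure function, to make the branching above
# --- easier to follow. The function name `decide_fast_pass_action` is hypothetical.
def decide_fast_pass_action(has_hash: bool, any_fast_ok: bool, all_missing: bool, has_states: bool) -> str:
    if not has_hash:
        # Seed asset (hash not computed yet): only safe to drop when every known path is gone.
        if has_states and all_missing:
            return "delete-asset-and-infos"
        return "leave-for-slow-scan"
    if any_fast_ok:
        # At least one path still matches mtime/size: prune stale rows, clear 'missing'.
        return "delete-missing-states-and-remove-missing-tag"
    # Hashed asset with no verifiable path left: flag it instead of deleting.
    return "add-missing-tag"


assert decide_fast_pass_action(False, False, True, True) == "delete-asset-and-infos"
assert decide_fast_pass_action(True, True, False, True) == "delete-missing-states-and-remove-missing-tag"
assert decide_fast_pass_action(True, False, True, True) == "add-missing-tag"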
@ -1,186 +0,0 @@
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, exists
|
||||
|
||||
from .models import AssetInfo, AssetInfoTag, Tag, AssetInfoMeta
|
||||
from .._assets_helpers import normalize_tags
|
||||
|
||||
|
||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply metadata filters using the projection table asset_info_meta.
|
||||
|
||||
Semantics:
|
||||
- For scalar values: require EXISTS(asset_info_meta) with matching key + typed value.
|
||||
- For None: key is missing OR key has explicit null (val_json IS NULL).
|
||||
- For list values: ANY-of the list elements matches (EXISTS for any).
|
||||
(Change to ALL-of by 'for each element: stmt = stmt.where(_meta_exists_clause(key, elem))')
|
||||
"""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
# Missing OR null:
|
||||
if value is None:
|
||||
# either: no row for key OR a row for key with explicit null
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
# Typed scalar matches:
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
# store as Decimal for equality against NUMERIC(38,10)
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
|
||||
# Complex: compare JSON (no index, but supported)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
# ANY-of (exists for any element)
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def is_scalar(v: Any) -> bool:
|
||||
if v is None: # treat None as a value (explicit null) so it can be indexed for "is null" queries
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_kv(key: str, value: Any) -> list[dict]:
|
||||
"""
|
||||
Turn a metadata key/value into one or more projection rows:
|
||||
- scalar -> one row (ordinal=0) in the proper typed column
|
||||
- list of scalars -> one row per element with ordinal=i
|
||||
- dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list)
|
||||
- None -> single row with all value columns NULL
|
||||
Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...}
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
# store numeric; SQLAlchemy will coerce to Numeric
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
# Fallback to json
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
# list contains objects -> one val_json per element
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
# Dict or any other structure -> single json row
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
@ -4,14 +4,20 @@ import shutil
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from comfy.cli_args import args
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
ENGINE: Optional[AsyncEngine] = None
|
||||
|
||||
23  app/database/helpers/__init__.py  Normal file
@ -0,0 +1,23 @@
from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv
from .tags import (
    add_missing_tag_for_asset_hash,
    add_missing_tag_for_asset_id,
    ensure_tags_exist,
    remove_missing_tag_for_asset_hash,
    remove_missing_tag_for_asset_id,
)

__all__ = [
    "apply_tag_filters",
    "apply_metadata_filter",
    "is_scalar",
    "project_kv",
    "ensure_tags_exist",
    "add_missing_tag_for_asset_id",
    "add_missing_tag_for_asset_hash",
    "remove_missing_tag_for_asset_id",
    "remove_missing_tag_for_asset_hash",
    "visible_owner_clause",
]
87  app/database/helpers/filters.py  Normal file
@ -0,0 +1,87 @@
from typing import Optional, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
from ..models import AssetInfo, AssetInfoMeta, AssetInfoTag
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
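# --- Illustrative sketch, not part of this diff: composing apply_tag_filters and
# --- apply_metadata_filter above into one query. The import paths follow the file layout shown in
# --- this diff, and the coroutine `search_assets` plus the tag/metadata values are hypothetical.
# from app.database.helpers.filters import apply_metadata_filter, apply_tag_filters
# from app.database.models import AssetInfo
import sqlalchemy as sa


async def search_assets(session, owner_id: str):
    stmt = sa.select(AssetInfo)
    # every include tag must be present; none of the exclude tags may be present
    stmt = apply_tag_filters(stmt, include_tags=["checkpoint"], exclude_tags=["missing"])
    # metadata: scalar -> typed equality, None -> key missing or explicit null, list -> any-of
    stmt = apply_metadata_filter(stmt, {"base_model": ["sdxl", "sd15"], "archived": None})
    return (await session.execute(stmt)).scalars().all()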
12  app/database/helpers/ownership.py  Normal file
@ -0,0 +1,12 @@
import sqlalchemy as sa

from ..models import AssetInfo


def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    """Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""

    owner_id = (owner_id or "").strip()
    if owner_id == "":
        return AssetInfo.owner_id == ""
    return AssetInfo.owner_id.in_(["", owner_id])
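# --- Illustrative sketch, not part of this diff: applying the visibility predicate above.
# --- An anonymous caller ("") only sees owner-less rows; a named owner also sees its own rows.
# --- The coroutine name `list_visible_infos` is hypothetical.
import sqlalchemy as sa


async def list_visible_infos(session, owner_id: str):
    stmt = sa.select(AssetInfo).where(visible_owner_clause(owner_id))
    return (await session.execute(stmt)).scalars().all()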
64  app/database/helpers/projection.py  Normal file
@ -0,0 +1,64 @@
from decimal import Decimal
|
||||
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def project_kv(key: str, value):
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
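# --- Illustrative sketch, not part of this diff: what project_kv above produces for a few
# --- representative metadata values (numeric scalar, list of strings, nested dict, explicit null).
# from app.database.helpers.projection import project_kv
from decimal import Decimal

assert project_kv("epochs", 30) == [{"key": "epochs", "ordinal": 0, "val_num": Decimal("30")}]
assert project_kv("triggers", ["a", "b"]) == [
    {"key": "triggers", "ordinal": 0, "val_str": "a"},
    {"key": "triggers", "ordinal": 1, "val_str": "b"},
]
assert project_kv("extra", {"nested": True}) == [{"key": "extra", "ordinal": 0, "val_json": {"nested": True}}]
assert project_kv("archived", None) == [
    {"key": "archived", "ordinal": 0, "val_str": None, "val_num": None, "val_bool": None, "val_json": None}
]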
102  app/database/helpers/tags.py  Normal file
@ -0,0 +1,102 @@
from typing import Iterable
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from ..._assets_helpers import normalize_tags
|
||||
from ..models import Asset, AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
|
||||
|
||||
async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type: str = "user") -> list[Tag]:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return []
|
||||
existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
|
||||
by_name = {t.name: t for t in existing}
|
||||
to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
|
||||
if to_create:
|
||||
session.add_all(to_create)
|
||||
await session.flush()
|
||||
by_name.update({t.name: t for t in to_create})
|
||||
return [by_name[n] for n in wanted]
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> int:
|
||||
"""Ensure every AssetInfo for asset_id has 'missing' tag."""
|
||||
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
|
||||
if not ids:
|
||||
return 0
|
||||
|
||||
existing = {
|
||||
asset_info_id
|
||||
for (asset_info_id,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.asset_info_id).where(
|
||||
AssetInfoTag.asset_info_id.in_(ids),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
to_add = [i for i in ids if i not in existing]
|
||||
if not to_add:
|
||||
return 0
|
||||
|
||||
now = utcnow()
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(asset_info_id=i, tag_name="missing", origin=origin, added_at=now)
|
||||
for i in to_add
|
||||
]
|
||||
)
|
||||
await session.flush()
|
||||
return len(to_add)
|
||||
|
||||
|
||||
async def add_missing_tag_for_asset_hash(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
origin: str = "automatic",
|
||||
) -> int:
|
||||
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
|
||||
if not asset:
|
||||
return 0
|
||||
return await add_missing_tag_for_asset_id(session, asset_id=asset.id, origin=origin)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> int:
|
||||
"""Remove the 'missing' tag from all AssetInfos for asset_id."""
|
||||
ids = (await session.execute(select(AssetInfo.id).where(AssetInfo.asset_id == asset_id))).scalars().all()
|
||||
if not ids:
|
||||
return 0
|
||||
|
||||
res = await session.execute(
|
||||
delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(ids),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
await session.flush()
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
async def remove_missing_tag_for_asset_hash(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> int:
|
||||
asset = (await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))).scalars().first()
|
||||
if not asset:
|
||||
return 0
|
||||
return await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
@ -1,27 +1,26 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import (
|
||||
Integer,
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
CheckConstraint,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Index,
|
||||
UniqueConstraint,
|
||||
JSON,
|
||||
Integer,
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
CheckConstraint,
|
||||
Numeric,
|
||||
Boolean,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from .timeutil import utcnow
|
||||
|
||||
|
||||
JSONB_V = JSON(none_as_null=True).with_variant(JSONB(none_as_null=True), 'postgresql')
|
||||
|
||||
|
||||
@ -46,7 +45,8 @@ def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
hash: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
hash: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
@ -56,8 +56,8 @@ class Asset(Base):
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.asset_hash),
|
||||
foreign_keys=lambda: [AssetInfo.asset_hash],
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
@ -65,8 +65,8 @@ class Asset(Base):
|
||||
preview_of: Mapped[list["AssetInfo"]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.hash == foreign(AssetInfo.preview_hash),
|
||||
foreign_keys=lambda: [AssetInfo.preview_hash],
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
@ -76,36 +76,32 @@ class Asset(Base):
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
locations: Mapped[list["AssetLocation"]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_assets_mime_type", "mime_type"),
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset hash={self.hash[:12]}>"
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_hash", "asset_hash"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
@ -114,27 +110,7 @@ class AssetCacheState(Base):
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} hash={self.asset_hash[:12]} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetLocation(Base):
|
||||
__tablename__ = "asset_locations"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False) # "gcs"
|
||||
locator: Mapped[str] = mapped_column(Text, nullable=False) # "gs://bucket/object"
|
||||
expected_size_bytes: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)
|
||||
etag: Mapped[Optional[str]] = mapped_column(String(256), nullable=True)
|
||||
last_modified: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
|
||||
asset: Mapped["Asset"] = relationship(back_populates="locations")
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_hash", "provider", "locator", name="uq_asset_locations_triplet"),
|
||||
Index("ix_asset_locations_hash", "asset_hash"),
|
||||
Index("ix_asset_locations_provider", "provider"),
|
||||
)
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
@ -143,31 +119,23 @@ class AssetInfo(Base):
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_hash: Mapped[str] = mapped_column(
|
||||
String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False
|
||||
)
|
||||
preview_hash: Mapped[Optional[str]] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[Optional[str]] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[Optional[dict[str, Any]]] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
|
||||
# Relationships
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
foreign_keys=[asset_hash],
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Optional[Asset]] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
foreign_keys=[preview_hash],
|
||||
foreign_keys=[preview_id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list["AssetInfoMeta"]] = relationship(
|
||||
@ -186,16 +154,16 @@ class AssetInfo(Base):
|
||||
tags: Mapped[list["Tag"]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
lazy="joined",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"),
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_hash", "asset_hash"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
@ -207,7 +175,7 @@ class AssetInfo(Base):
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} hash={self.asset_hash[:12]}>"
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
|
||||
File diff suppressed because it is too large
56  app/database/services/__init__.py  Normal file
@ -0,0 +1,56 @@
from .content import (
|
||||
check_fs_asset_exists_quick,
|
||||
compute_hash_and_dedup_for_cache_state,
|
||||
ensure_seed_for_path,
|
||||
ingest_fs_asset,
|
||||
list_cache_states_with_asset_under_prefixes,
|
||||
list_unhashed_candidates_under_prefixes,
|
||||
list_verify_candidates_under_prefixes,
|
||||
redirect_all_references_then_delete_asset,
|
||||
touch_asset_infos_by_fs_path,
|
||||
)
|
||||
from .info import (
|
||||
add_tags_to_asset_info,
|
||||
create_asset_info_for_existing_asset,
|
||||
delete_asset_info_by_id,
|
||||
fetch_asset_info_and_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_tags,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_asset_info,
|
||||
replace_asset_info_metadata_projection,
|
||||
set_asset_info_preview,
|
||||
set_asset_info_tags,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
)
|
||||
from .queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
get_cache_state_by_asset_id,
|
||||
list_cache_states_by_asset_id,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# queries
|
||||
"asset_exists_by_hash", "get_asset_by_hash", "get_asset_info_by_id", "asset_info_exists_for_asset_id",
|
||||
"get_cache_state_by_asset_id",
|
||||
"list_cache_states_by_asset_id",
|
||||
# info
|
||||
"list_asset_infos_page", "create_asset_info_for_existing_asset", "set_asset_info_tags",
|
||||
"update_asset_info_full", "replace_asset_info_metadata_projection",
|
||||
"touch_asset_info_by_id", "delete_asset_info_by_id",
|
||||
"add_tags_to_asset_info", "remove_tags_from_asset_info",
|
||||
"get_asset_tags", "list_tags_with_usage", "set_asset_info_preview",
|
||||
"fetch_asset_info_and_asset", "fetch_asset_info_asset_and_tags",
|
||||
# content
|
||||
"check_fs_asset_exists_quick", "ensure_seed_for_path",
|
||||
"redirect_all_references_then_delete_asset",
|
||||
"compute_hash_and_dedup_for_cache_state",
|
||||
"list_unhashed_candidates_under_prefixes", "list_verify_candidates_under_prefixes",
|
||||
"ingest_fs_asset", "touch_asset_infos_by_fs_path",
|
||||
"list_cache_states_with_asset_under_prefixes",
|
||||
]
|
||||
746  app/database/services/content.py  Normal file
@ -0,0 +1,746 @@
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import postgresql as d_pg
|
||||
from sqlalchemy.dialects import sqlite as d_sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from ..._assets_helpers import compute_model_relative_filename, normalize_tags
|
||||
from ...storage import hashing as hashing_mod
|
||||
from ..helpers import (
|
||||
ensure_tags_exist,
|
||||
remove_missing_tag_for_asset_id,
|
||||
)
|
||||
from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, Tag
|
||||
from ..timeutil import utcnow
|
||||
from .info import replace_asset_info_metadata_projection
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
file_path: str,
|
||||
size_bytes: Optional[int] = None,
|
||||
mtime_ns: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Return True if a cache row exists for this absolute path and (optionally) mtime/size match."""
|
||||
locator = os.path.abspath(file_path)
|
||||
|
||||
stmt = (
|
||||
sa.select(sa.literal(True))
|
||||
.select_from(AssetCacheState)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
conds = []
|
||||
if mtime_ns is not None:
|
||||
conds.append(AssetCacheState.mtime_ns == int(mtime_ns))
|
||||
if size_bytes is not None:
|
||||
conds.append(sa.or_(Asset.size_bytes == 0, Asset.size_bytes == int(size_bytes)))
|
||||
|
||||
if conds:
|
||||
stmt = stmt.where(*conds)
|
||||
|
||||
row = (await session.execute(stmt)).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
async def ensure_seed_for_path(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
info_name: str,
|
||||
tags: Sequence[str],
|
||||
owner_id: str = "",
|
||||
) -> str:
|
||||
"""Ensure: Asset(hash=NULL), AssetCacheState(file_path), and AssetInfo exist for the path. Returns asset_id."""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
state = (
|
||||
await session.execute(
|
||||
sa.select(AssetCacheState, Asset)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.limit(1)
|
||||
)
|
||||
).first()
|
||||
if state:
|
||||
state_row: AssetCacheState = state[0]
|
||||
asset_row: Asset = state[1]
|
||||
changed = state_row.mtime_ns is None or int(state_row.mtime_ns) != int(mtime_ns)
|
||||
if changed:
|
||||
state_row.mtime_ns = int(mtime_ns)
|
||||
state_row.needs_verify = True
|
||||
if asset_row.size_bytes == 0 and size_bytes > 0:
|
||||
asset_row.size_bytes = int(size_bytes)
|
||||
return asset_row.id
|
||||
|
||||
# Create new asset (hash=NULL)
|
||||
asset = Asset(hash=None, size_bytes=int(size_bytes), mime_type=None, created_at=now)
|
||||
session.add(asset)
|
||||
await session.flush() # to get id
|
||||
|
||||
cs = AssetCacheState(asset_id=asset.id, file_path=locator, mtime_ns=int(mtime_ns), needs_verify=False)
|
||||
session.add(cs)
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
await session.flush()
|
||||
|
||||
# Attach tags
|
||||
want = normalize_tags(tags)
|
||||
if want:
|
||||
await ensure_tags_exist(session, want, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=info.id, tag_name=t, origin="automatic", added_at=now)
|
||||
for t in want
|
||||
])
|
||||
|
||||
await session.flush()
|
||||
return asset.id
|
||||
|
||||
|
||||
async def redirect_all_references_then_delete_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
duplicate_asset_id: str,
|
||||
canonical_asset_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Safely migrate all references from duplicate_asset_id to canonical_asset_id.
|
||||
|
||||
- If an AssetInfo for (owner_id, name) already exists on the canonical asset,
|
||||
merge tags, metadata, times, and preview, then delete the duplicate AssetInfo.
|
||||
- Otherwise, simply repoint the AssetInfo.asset_id.
|
||||
- Always retarget AssetCacheState rows.
|
||||
- Finally delete the duplicate Asset row.
|
||||
"""
|
||||
if duplicate_asset_id == canonical_asset_id:
|
||||
return
|
||||
|
||||
# 1) Migrate AssetInfo rows one-by-one to avoid UNIQUE conflicts.
|
||||
dup_infos = (
|
||||
await session.execute(
|
||||
select(AssetInfo).options(noload(AssetInfo.tags)).where(AssetInfo.asset_id == duplicate_asset_id)
|
||||
)
|
||||
).unique().scalars().all()
|
||||
|
||||
for info in dup_infos:
|
||||
# Try to find an existing collision on canonical
|
||||
existing = (
|
||||
await session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == canonical_asset_id,
|
||||
AssetInfo.owner_id == info.owner_id,
|
||||
AssetInfo.name == info.name,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
|
||||
if existing:
|
||||
# Merge metadata (prefer existing keys, fill gaps from duplicate)
|
||||
merged_meta = dict(existing.user_metadata or {})
|
||||
other_meta = info.user_metadata or {}
|
||||
for k, v in other_meta.items():
|
||||
if k not in merged_meta:
|
||||
merged_meta[k] = v
|
||||
if merged_meta != (existing.user_metadata or {}):
|
||||
await replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=existing.id,
|
||||
user_metadata=merged_meta,
|
||||
)
|
||||
|
||||
# Merge tags (union)
|
||||
existing_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == existing.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
from_tags = {
|
||||
t for (t,) in (
|
||||
await session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == info.id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
to_add = sorted(from_tags - existing_tags)
|
||||
if to_add:
|
||||
await ensure_tags_exist(session, to_add, tag_type="user")
|
||||
now = utcnow()
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=existing.id, tag_name=t, origin="automatic", added_at=now)
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
|
||||
# Merge preview and times
|
||||
if existing.preview_id is None and info.preview_id is not None:
|
||||
existing.preview_id = info.preview_id
|
||||
if info.last_access_time and (
|
||||
existing.last_access_time is None or info.last_access_time > existing.last_access_time
|
||||
):
|
||||
existing.last_access_time = info.last_access_time
|
||||
existing.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# Delete the duplicate AssetInfo (cascades will clean its tags/meta)
|
||||
await session.delete(info)
|
||||
await session.flush()
|
||||
else:
|
||||
# Simple retarget
|
||||
info.asset_id = canonical_asset_id
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
# 2) Repoint cache states and previews
|
||||
await session.execute(
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == duplicate_asset_id)
|
||||
.values(asset_id=canonical_asset_id)
|
||||
)
|
||||
await session.execute(
|
||||
sa.update(AssetInfo)
|
||||
.where(AssetInfo.preview_id == duplicate_asset_id)
|
||||
.values(preview_id=canonical_asset_id)
|
||||
)
|
||||
|
||||
# 3) Remove duplicate Asset
|
||||
dup = await session.get(Asset, duplicate_asset_id)
|
||||
if dup:
|
||||
await session.delete(dup)
|
||||
await session.flush()
|
||||
|
||||
|
||||
async def compute_hash_and_dedup_for_cache_state(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
state_id: int,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Compute hash for the given cache state, deduplicate, and settle verify cases.
|
||||
|
||||
Returns the asset_id that this state ends up pointing to, or None if file disappeared.
|
||||
"""
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if not state:
|
||||
return None
|
||||
|
||||
path = state.file_path
|
||||
try:
|
||||
if not os.path.isfile(path):
|
||||
# File vanished: drop the state. If the Asset was a seed (hash NULL)
|
||||
# and has no other states, drop the Asset too.
|
||||
asset = await session.get(Asset, state.asset_id)
|
||||
await session.delete(state)
|
||||
await session.flush()
|
||||
|
||||
if asset and asset.hash is None:
|
||||
remaining = (
|
||||
await session.execute(
|
||||
sa.select(sa.func.count())
|
||||
.select_from(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset.id)
|
||||
)
|
||||
).scalar_one()
|
||||
if int(remaining or 0) == 0:
|
||||
await session.delete(asset)
|
||||
await session.flush()
|
||||
return None
|
||||
|
||||
digest = await hashing_mod.blake3_hash(path)
|
||||
new_hash = f"blake3:{digest}"
|
||||
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
new_size = int(st.st_size)
|
||||
mtime_ns = getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
# Current asset of this state
|
||||
this_asset = await session.get(Asset, state.asset_id)
|
||||
|
||||
# If the state got orphaned somehow (race), just reattach appropriately.
|
||||
if not this_asset:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
state.asset_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
state.asset_id = new_asset.id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=state.asset_id)
|
||||
except Exception:
|
||||
pass
|
||||
await session.flush()
|
||||
return state.asset_id
|
||||
|
||||
# 1) Seed asset case (hash is NULL): claim or merge into canonical
|
||||
if this_asset.hash is None:
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
# Merge seed asset into canonical (safe, collision-aware)
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
# Refresh state after the merge
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
except Exception:
|
||||
pass
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
|
||||
# No canonical: try to claim the hash; handle races with a SAVEPOINT
|
||||
try:
|
||||
async with session.begin_nested():
|
||||
this_asset.hash = new_hash
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Someone else claimed it concurrently; fetch canonical and merge
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical and canonical.id != this_asset.id:
|
||||
await redirect_all_references_then_delete_asset(
|
||||
session,
|
||||
duplicate_asset_id=this_asset.id,
|
||||
canonical_asset_id=canonical.id,
|
||||
)
|
||||
state = await session.get(AssetCacheState, state_id)
|
||||
if state:
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=canonical.id)
|
||||
except Exception:
|
||||
pass
|
||||
await session.flush()
|
||||
return canonical.id
|
||||
# If we got here, the integrity error was not about hash uniqueness
|
||||
raise
|
||||
|
||||
# Claimed successfully
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
except Exception:
|
||||
pass
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# 2) Verify case for hashed assets
|
||||
if this_asset.hash == new_hash:
|
||||
# Content unchanged; tidy up sizes/mtime
|
||||
if int(this_asset.size_bytes or 0) == 0 and new_size > 0:
|
||||
this_asset.size_bytes = new_size
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=this_asset.id)
|
||||
except Exception:
|
||||
pass
|
||||
await session.flush()
|
||||
return this_asset.id
|
||||
|
||||
# Content changed on this path only: retarget THIS state, do not move AssetInfo rows
|
||||
canonical = (
|
||||
await session.execute(sa.select(Asset).where(Asset.hash == new_hash).limit(1))
|
||||
).scalars().first()
|
||||
if canonical:
|
||||
target_id = canonical.id
|
||||
else:
|
||||
now = utcnow()
|
||||
new_asset = Asset(hash=new_hash, size_bytes=new_size, mime_type=None, created_at=now)
|
||||
session.add(new_asset)
|
||||
await session.flush()
|
||||
target_id = new_asset.id
|
||||
|
||||
state.asset_id = target_id
|
||||
state.mtime_ns = mtime_ns
|
||||
state.needs_verify = False
|
||||
try:
|
||||
await remove_missing_tag_for_asset_id(session, asset_id=target_id)
|
||||
except Exception:
|
||||
pass
|
||||
await session.flush()
|
||||
return target_id
|
||||
|
||||
except Exception:
|
||||
# Propagate; caller records the error and continues the worker.
|
||||
raise
|
||||
|
||||
|
||||
async def list_unhashed_candidates_under_prefixes(
|
||||
session: AsyncSession, *, prefixes: Sequence[str]
|
||||
) -> list[int]:
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||
|
||||
rows = (
|
||||
await session.execute(
|
||||
sa.select(AssetCacheState.id)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(Asset.hash.is_(None))
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
seen = set()
|
||||
result: list[int] = []
|
||||
for sid in rows:
|
||||
st = await session.get(AssetCacheState, sid)
|
||||
if st and st.asset_id not in seen:
|
||||
seen.add(st.asset_id)
|
||||
result.append(sid)
|
||||
return result
|
||||
|
||||
|
||||
async def list_verify_candidates_under_prefixes(
|
||||
session: AsyncSession, *, prefixes: Sequence[str]
|
||||
) -> Union[list[int], Sequence[int]]:
|
||||
if not prefixes:
|
||||
return []
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
conds.append(AssetCacheState.file_path.like(base + "%"))
|
||||
|
||||
return (
|
||||
await session.execute(
|
||||
sa.select(AssetCacheState.id)
|
||||
.where(AssetCacheState.needs_verify.is_(True))
|
||||
.where(sa.or_(*conds))
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
async def ingest_fs_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: Optional[str] = None,
|
||||
info_name: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
preview_id: Optional[str] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
    - AssetCacheState(file_path) pointing to asset_id
    - Optionally AssetInfo + tag links and metadata projection
    Returns flags and ids.
    """
    locator = os.path.abspath(abs_path)
    now = utcnow()

    if preview_id:
        if not await session.get(Asset, preview_id):
            preview_id = None

    out: dict[str, Any] = {
        "asset_created": False,
        "asset_updated": False,
        "state_created": False,
        "state_updated": False,
        "asset_info_id": None,
    }

    # 1) Asset by hash
    asset = (
        await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()
    if not asset:
        async with session.begin_nested():
            asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now)
            session.add(asset)
            await session.flush()
        out["asset_created"] = True
    else:
        changed = False
        if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
            asset.size_bytes = int(size_bytes)
            changed = True
        if mime_type and asset.mime_type != mime_type:
            asset.mime_type = mime_type
            changed = True
        if changed:
            out["asset_updated"] = True

    # 2) AssetCacheState upsert by file_path (unique)
    vals = {
        "asset_id": asset.id,
        "file_path": locator,
        "mtime_ns": int(mtime_ns),
    }
    dialect = session.bind.dialect.name
    if dialect == "sqlite":
        ins = (
            d_sqlite.insert(AssetCacheState)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
        )
    elif dialect == "postgresql":
        ins = (
            d_pg.insert(AssetCacheState)
            .values(**vals)
            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
        )
    else:
        raise NotImplementedError(f"Unsupported database dialect: {dialect}")

    res = await session.execute(ins)
    if int(res.rowcount or 0) > 0:
        out["state_created"] = True
    else:
        upd = (
            sa.update(AssetCacheState)
            .where(AssetCacheState.file_path == locator)
            .where(
                sa.or_(
                    AssetCacheState.asset_id != asset.id,
                    AssetCacheState.mtime_ns.is_(None),
                    AssetCacheState.mtime_ns != int(mtime_ns),
                )
            )
            .values(asset_id=asset.id, mtime_ns=int(mtime_ns))
        )
        res2 = await session.execute(upd)
        if int(res2.rowcount or 0) > 0:
            out["state_updated"] = True

    # 3) Optional AssetInfo + tags + metadata
    if info_name:
        # upsert by (asset_id, owner_id, name)
        try:
            async with session.begin_nested():
                info = AssetInfo(
                    owner_id=owner_id,
                    name=info_name,
                    asset_id=asset.id,
                    preview_id=preview_id,
                    created_at=now,
                    updated_at=now,
                    last_access_time=now,
                )
                session.add(info)
                await session.flush()
                out["asset_info_id"] = info.id
        except IntegrityError:
            pass

        existing_info = (
            await session.execute(
                select(AssetInfo)
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == info_name,
                    (AssetInfo.owner_id == owner_id),
                )
                .limit(1)
            )
        ).unique().scalar_one_or_none()
        if not existing_info:
            raise RuntimeError("Failed to update or insert AssetInfo.")

        if preview_id and existing_info.preview_id != preview_id:
            existing_info.preview_id = preview_id

        existing_info.updated_at = now
        if existing_info.last_access_time < now:
            existing_info.last_access_time = now
        await session.flush()
        out["asset_info_id"] = existing_info.id

        norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
        if norm and out["asset_info_id"] is not None:
            if not require_existing_tags:
                await ensure_tags_exist(session, norm, tag_type="user")

            existing_tag_names = set(
                name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
            )
            missing = [t for t in norm if t not in existing_tag_names]
            if missing and require_existing_tags:
                raise ValueError(f"Unknown tags: {missing}")

            existing_links = set(
                tag_name
                for (tag_name,) in (
                    await session.execute(
                        select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
                    )
                ).all()
            )
            to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
            if to_add:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=out["asset_info_id"],
                            tag_name=t,
                            origin=tag_origin,
                            added_at=now,
                        )
                        for t in to_add
                    ]
                )
                await session.flush()

        # metadata["filename"] hack
        if out["asset_info_id"] is not None:
            primary_path = (
                await session.execute(
                    select(AssetCacheState.file_path)
                    .where(AssetCacheState.asset_id == asset.id)
                    .order_by(AssetCacheState.id.asc())
                    .limit(1)
                )
            ).scalars().first()
            computed_filename = compute_model_relative_filename(primary_path) if primary_path else None

            current_meta = existing_info.user_metadata or {}
            new_meta = dict(current_meta)
            if user_metadata is not None:
                for k, v in user_metadata.items():
                    new_meta[k] = v
            if computed_filename:
                new_meta["filename"] = computed_filename

            if new_meta != current_meta:
                await replace_asset_info_metadata_projection(
                    session,
                    asset_info_id=out["asset_info_id"],
                    user_metadata=new_meta,
                )

    try:
        await remove_missing_tag_for_asset_id(session, asset_id=asset.id)
    except Exception:
        logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
    return out

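For reference (not part of this commit), a minimal self-contained sketch of the dialect-specific ON CONFLICT DO NOTHING pattern used in step 2 above; the "files" table and its columns here are illustrative only, not part of ComfyUI:

    # Illustrative sketch; the table and column names are hypothetical.
    import sqlalchemy as sa
    from sqlalchemy.dialects.postgresql import insert as pg_insert
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    metadata = sa.MetaData()
    files = sa.Table(
        "files",
        metadata,
        sa.Column("file_path", sa.String, primary_key=True),
        sa.Column("mtime_ns", sa.BigInteger),
    )

    def insert_ignore(dialect_name: str, values: dict):
        # Pick the dialect-specific insert construct, then skip the row on a
        # conflict with the unique file_path instead of raising IntegrityError.
        if dialect_name == "sqlite":
            ins = sqlite_insert(files)
        elif dialect_name == "postgresql":
            ins = pg_insert(files)
        else:
            raise NotImplementedError(dialect_name)
        return ins.values(**values).on_conflict_do_nothing(index_elements=["file_path"])
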
async def touch_asset_infos_by_fs_path(
    session: AsyncSession,
    *,
    file_path: str,
    ts: Optional[datetime] = None,
    only_if_newer: bool = True,
) -> int:
    locator = os.path.abspath(file_path)
    ts = ts or utcnow()

    stmt = sa.update(AssetInfo).where(
        sa.exists(
            sa.select(sa.literal(1))
            .select_from(AssetCacheState)
            .where(
                AssetCacheState.asset_id == AssetInfo.asset_id,
                AssetCacheState.file_path == locator,
            )
        )
    )

    if only_if_newer:
        stmt = stmt.where(
            sa.or_(
                AssetInfo.last_access_time.is_(None),
                AssetInfo.last_access_time < ts,
            )
        )

    stmt = stmt.values(last_access_time=ts)

    res = await session.execute(stmt)
    return int(res.rowcount or 0)

async def list_cache_states_with_asset_under_prefixes(
    session: AsyncSession,
    *,
    prefixes: Sequence[str],
) -> list[tuple[AssetCacheState, Optional[str], int]]:
    """Return (AssetCacheState, asset_hash, size_bytes) for rows under any prefix."""
    if not prefixes:
        return []

    conds = []
    for p in prefixes:
        if not p:
            continue
        base = os.path.abspath(p)
        if not base.endswith(os.sep):
            base = base + os.sep
        conds.append(AssetCacheState.file_path.like(base + "%"))

    if not conds:
        return []

    rows = (
        await session.execute(
            select(AssetCacheState, Asset.hash, Asset.size_bytes)
            .join(Asset, Asset.id == AssetCacheState.asset_id)
            .where(sa.or_(*conds))
            .order_by(AssetCacheState.id.asc())
        )
    ).all()
    return [(r[0], r[1], int(r[2] or 0)) for r in rows]
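A hedged usage sketch of the prefix query above, assuming an existing async session factory ("async_session" is an assumed name, not defined in this diff):

    # Sketch; async_session is an assumed sessionmaker.
    async def find_models_cache(async_session, model_dirs: list[str]):
        async with async_session() as session:
            rows = await list_cache_states_with_asset_under_prefixes(session, prefixes=model_dirs)
            for state, asset_hash, size_bytes in rows:
                print(state.file_path, asset_hash, size_bytes)
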
app/database/services/info.py (new file)
@ -0,0 +1,579 @@
from collections import defaultdict
from datetime import datetime
from typing import Any, Optional, Sequence

import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import contains_eager, noload

from ..._assets_helpers import compute_model_relative_filename, normalize_tags
from ..helpers import (
    apply_metadata_filter,
    apply_tag_filters,
    ensure_tags_exist,
    project_kv,
    visible_owner_clause,
)
from ..models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from ..timeutil import utcnow
from .queries import get_asset_by_hash, get_cache_state_by_asset_id


async def list_asset_infos_page(
    session: AsyncSession,
    *,
    owner_id: str = "",
    include_tags: Optional[Sequence[str]] = None,
    exclude_tags: Optional[Sequence[str]] = None,
    name_contains: Optional[str] = None,
    metadata_filter: Optional[dict] = None,
    limit: int = 20,
    offset: int = 0,
    sort: str = "created_at",
    order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
    base = (
        select(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
        .where(visible_owner_clause(owner_id))
    )

    if name_contains:
        base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))

    base = apply_tag_filters(base, include_tags, exclude_tags)
    base = apply_metadata_filter(base, metadata_filter)

    sort = (sort or "created_at").lower()
    order = (order or "desc").lower()
    sort_map = {
        "name": AssetInfo.name,
        "created_at": AssetInfo.created_at,
        "updated_at": AssetInfo.updated_at,
        "last_access_time": AssetInfo.last_access_time,
        "size": Asset.size_bytes,
    }
    sort_col = sort_map.get(sort, AssetInfo.created_at)
    sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()

    base = base.order_by(sort_exp).limit(limit).offset(offset)

    count_stmt = (
        select(func.count())
        .select_from(AssetInfo)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(visible_owner_clause(owner_id))
    )
    if name_contains:
        count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
    count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
    count_stmt = apply_metadata_filter(count_stmt, metadata_filter)

    total = int((await session.execute(count_stmt)).scalar_one() or 0)

    infos = (await session.execute(base)).unique().scalars().all()

    id_list: list[str] = [i.id for i in infos]
    tag_map: dict[str, list[str]] = defaultdict(list)
    if id_list:
        rows = await session.execute(
            select(AssetInfoTag.asset_info_id, Tag.name)
            .join(Tag, Tag.name == AssetInfoTag.tag_name)
            .where(AssetInfoTag.asset_info_id.in_(id_list))
        )
        for aid, tag_name in rows.all():
            tag_map[aid].append(tag_name)

    return infos, tag_map, total

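As a quick orientation (not part of the commit), calling the pager above might look like the sketch below; it assumes an open AsyncSession, and the tag value "models" is only an example:

    # Sketch; assumes `session` is an open AsyncSession.
    infos, tag_map, total = await list_asset_infos_page(
        session,
        owner_id="",
        include_tags=["models"],
        limit=20,
        offset=0,
        sort="created_at",
        order="desc",
    )
    for info in infos:
        print(info.name, tag_map.get(info.id, []), "of", total)
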
async def fetch_asset_info_and_asset(
    session: AsyncSession,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
    stmt = (
        select(AssetInfo, Asset)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .limit(1)
        .options(noload(AssetInfo.tags))
    )
    row = await session.execute(stmt)
    pair = row.first()
    if not pair:
        return None
    return pair[0], pair[1]


async def fetch_asset_info_asset_and_tags(
    session: AsyncSession,
    *,
    asset_info_id: str,
    owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
    stmt = (
        select(AssetInfo, Asset, Tag.name)
        .join(Asset, Asset.id == AssetInfo.asset_id)
        .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
        .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
        .where(
            AssetInfo.id == asset_info_id,
            visible_owner_clause(owner_id),
        )
        .options(noload(AssetInfo.tags))
        .order_by(Tag.name.asc())
    )

    rows = (await session.execute(stmt)).all()
    if not rows:
        return None

    first_info, first_asset, _ = rows[0]
    tags: list[str] = []
    seen: set[str] = set()
    for _info, _asset, tag_name in rows:
        if tag_name and tag_name not in seen:
            seen.add(tag_name)
            tags.append(tag_name)
    return first_info, first_asset, tags


async def create_asset_info_for_existing_asset(
    session: AsyncSession,
    *,
    asset_hash: str,
    name: str,
    user_metadata: Optional[dict] = None,
    tags: Optional[Sequence[str]] = None,
    tag_origin: str = "manual",
    owner_id: str = "",
) -> AssetInfo:
    """Create or return an existing AssetInfo for an Asset identified by asset_hash."""
    now = utcnow()
    asset = await get_asset_by_hash(session, asset_hash=asset_hash)
    if not asset:
        raise ValueError(f"Unknown asset hash {asset_hash}")

    info = AssetInfo(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        preview_id=None,
        created_at=now,
        updated_at=now,
        last_access_time=now,
    )
    try:
        async with session.begin_nested():
            session.add(info)
            await session.flush()
    except IntegrityError:
        existing = (
            await session.execute(
                select(AssetInfo)
                .options(noload(AssetInfo.tags))
                .where(
                    AssetInfo.asset_id == asset.id,
                    AssetInfo.name == name,
                    AssetInfo.owner_id == owner_id,
                )
                .limit(1)
            )
        ).unique().scalars().first()
        if not existing:
            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
        return existing

    # metadata["filename"] hack
    new_meta = dict(user_metadata or {})
    computed_filename = None
    try:
        state = await get_cache_state_by_asset_id(session, asset_id=asset.id)
        if state and state.file_path:
            computed_filename = compute_model_relative_filename(state.file_path)
    except Exception:
        computed_filename = None
    if computed_filename:
        new_meta["filename"] = computed_filename
    if new_meta:
        await replace_asset_info_metadata_projection(
            session,
            asset_info_id=info.id,
            user_metadata=new_meta,
        )

    if tags is not None:
        await set_asset_info_tags(
            session,
            asset_info_id=info.id,
            tags=tags,
            origin=tag_origin,
        )
    return info


async def set_asset_info_tags(
    session: AsyncSession,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
) -> dict:
    desired = normalize_tags(tags)

    current = set(
        tag_name for (tag_name,) in (
            await session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
        ).all()
    )

    to_add = [t for t in desired if t not in current]
    to_remove = [t for t in current if t not in desired]

    if to_add:
        await ensure_tags_exist(session, to_add, tag_type="user")
        session.add_all([
            AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
            for t in to_add
        ])
        await session.flush()

    if to_remove:
        await session.execute(
            delete(AssetInfoTag)
            .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
        )
        await session.flush()

    return {"added": to_add, "removed": to_remove, "total": desired}

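A hedged usage sketch of the tag replacement helper above; the tag values and the `info_id` variable are illustrative only:

    # Sketch; replaces the full tag set and reports the diff that was applied.
    result = await set_asset_info_tags(
        session,
        asset_info_id=info_id,
        tags=["Checkpoints", "sdxl", "SDXL"],  # normalized/deduplicated by normalize_tags
        origin="manual",
    )
    # result is of the form {"added": [...], "removed": [...], "total": [...]}
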
async def update_asset_info_full(
    session: AsyncSession,
    *,
    asset_info_id: str,
    name: Optional[str] = None,
    tags: Optional[Sequence[str]] = None,
    user_metadata: Optional[dict] = None,
    tag_origin: str = "manual",
    asset_info_row: Any = None,
) -> AssetInfo:
    if not asset_info_row:
        info = await session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")
    else:
        info = asset_info_row

    touched = False
    if name is not None and name != info.name:
        info.name = name
        touched = True

    computed_filename = None
    try:
        state = await get_cache_state_by_asset_id(session, asset_id=info.asset_id)
        if state and state.file_path:
            computed_filename = compute_model_relative_filename(state.file_path)
    except Exception:
        computed_filename = None

    if user_metadata is not None:
        new_meta = dict(user_metadata)
        if computed_filename:
            new_meta["filename"] = computed_filename
        await replace_asset_info_metadata_projection(
            session, asset_info_id=asset_info_id, user_metadata=new_meta
        )
        touched = True
    else:
        if computed_filename:
            current_meta = info.user_metadata or {}
            if current_meta.get("filename") != computed_filename:
                new_meta = dict(current_meta)
                new_meta["filename"] = computed_filename
                await replace_asset_info_metadata_projection(
                    session, asset_info_id=asset_info_id, user_metadata=new_meta
                )
                touched = True

    if tags is not None:
        await set_asset_info_tags(
            session,
            asset_info_id=asset_info_id,
            tags=tags,
            origin=tag_origin,
        )
        touched = True

    if touched and user_metadata is None:
        info.updated_at = utcnow()
        await session.flush()

    return info


async def replace_asset_info_metadata_projection(
    session: AsyncSession,
    *,
    asset_info_id: str,
    user_metadata: Optional[dict],
) -> None:
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    info.user_metadata = user_metadata or {}
    info.updated_at = utcnow()
    await session.flush()

    await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
    await session.flush()

    if not user_metadata:
        return

    rows: list[AssetInfoMeta] = []
    for k, v in user_metadata.items():
        for r in project_kv(k, v):
            rows.append(
                AssetInfoMeta(
                    asset_info_id=asset_info_id,
                    key=r["key"],
                    ordinal=int(r["ordinal"]),
                    val_str=r.get("val_str"),
                    val_num=r.get("val_num"),
                    val_bool=r.get("val_bool"),
                    val_json=r.get("val_json"),
                )
            )
    if rows:
        session.add_all(rows)
        await session.flush()


async def touch_asset_info_by_id(
    session: AsyncSession,
    *,
    asset_info_id: str,
    ts: Optional[datetime] = None,
    only_if_newer: bool = True,
) -> int:
    ts = ts or utcnow()
    stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
    if only_if_newer:
        stmt = stmt.where(
            sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
        )
    stmt = stmt.values(last_access_time=ts)
    res = await session.execute(stmt)
    return int(res.rowcount or 0)


async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: str, owner_id: str) -> bool:
    res = await session.execute(delete(AssetInfo).where(
        AssetInfo.id == asset_info_id,
        visible_owner_clause(owner_id),
    ))
    return bool(res.rowcount)


async def add_tags_to_asset_info(
    session: AsyncSession,
    *,
    asset_info_id: str,
    tags: Sequence[str],
    origin: str = "manual",
    create_if_missing: bool = True,
    asset_info_row: Any = None,
) -> dict:
    if not asset_info_row:
        info = await session.get(AssetInfo, asset_info_id)
        if not info:
            raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = await get_asset_tags(session, asset_info_id=asset_info_id)
        return {"added": [], "already_present": [], "total_tags": total}

    if create_if_missing:
        await ensure_tags_exist(session, norm, tag_type="user")

    current = {
        tag_name
        for (tag_name,) in (
            await session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    want = set(norm)
    to_add = sorted(want - current)

    if to_add:
        async with session.begin_nested() as nested:
            try:
                session.add_all(
                    [
                        AssetInfoTag(
                            asset_info_id=asset_info_id,
                            tag_name=t,
                            origin=origin,
                            added_at=utcnow(),
                        )
                        for t in to_add
                    ]
                )
                await session.flush()
            except IntegrityError:
                await nested.rollback()

    after = set(await get_asset_tags(session, asset_info_id=asset_info_id))
    return {
        "added": sorted(((after - current) & want)),
        "already_present": sorted(want & current),
        "total_tags": sorted(after),
    }


async def remove_tags_from_asset_info(
    session: AsyncSession,
    *,
    asset_info_id: str,
    tags: Sequence[str],
) -> dict:
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    norm = normalize_tags(tags)
    if not norm:
        total = await get_asset_tags(session, asset_info_id=asset_info_id)
        return {"removed": [], "not_present": [], "total_tags": total}

    existing = {
        tag_name
        for (tag_name,) in (
            await session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    }

    to_remove = sorted(set(t for t in norm if t in existing))
    not_present = sorted(set(t for t in norm if t not in existing))

    if to_remove:
        await session.execute(
            delete(AssetInfoTag)
            .where(
                AssetInfoTag.asset_info_id == asset_info_id,
                AssetInfoTag.tag_name.in_(to_remove),
            )
        )
        await session.flush()

    total = await get_asset_tags(session, asset_info_id=asset_info_id)
    return {"removed": to_remove, "not_present": not_present, "total_tags": total}


async def list_tags_with_usage(
    session: AsyncSession,
    *,
    prefix: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
    include_zero: bool = True,
    order: str = "count_desc",
    owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
    counts_sq = (
        select(
            AssetInfoTag.tag_name.label("tag_name"),
            func.count(AssetInfoTag.asset_info_id).label("cnt"),
        )
        .select_from(AssetInfoTag)
        .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
        .where(visible_owner_clause(owner_id))
        .group_by(AssetInfoTag.tag_name)
        .subquery()
    )

    q = (
        select(
            Tag.name,
            Tag.tag_type,
            func.coalesce(counts_sq.c.cnt, 0).label("count"),
        )
        .select_from(Tag)
        .join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
    )

    if prefix:
        q = q.where(Tag.name.like(prefix.strip().lower() + "%"))

    if not include_zero:
        q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)

    if order == "name_asc":
        q = q.order_by(Tag.name.asc())
    else:
        q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())

    total_q = select(func.count()).select_from(Tag)
    if prefix:
        total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
    if not include_zero:
        total_q = total_q.where(
            Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
        )

    rows = (await session.execute(q.limit(limit).offset(offset))).all()
    total = (await session.execute(total_q)).scalar_one()

    rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
    return rows_norm, int(total or 0)


async def get_asset_tags(session: AsyncSession, *, asset_info_id: str) -> list[str]:
    return [
        tag_name
        for (tag_name,) in (
            await session.execute(
                sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
            )
        ).all()
    ]


async def set_asset_info_preview(
    session: AsyncSession,
    *,
    asset_info_id: str,
    preview_asset_id: Optional[str],
) -> None:
    """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
    info = await session.get(AssetInfo, asset_info_id)
    if not info:
        raise ValueError(f"AssetInfo {asset_info_id} not found")

    if preview_asset_id is None:
        info.preview_id = None
    else:
        # validate preview asset exists
        if not await session.get(Asset, preview_asset_id):
            raise ValueError(f"Preview Asset {preview_asset_id} not found")
        info.preview_id = preview_asset_id

    info.updated_at = utcnow()
    await session.flush()
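A short usage sketch for the preview setter above (the `info_id` and `preview_asset` names are illustrative, assuming an open AsyncSession):

    # Sketch; clear a preview, then assign an existing Asset id as the preview.
    await set_asset_info_preview(session, asset_info_id=info_id, preview_asset_id=None)
    await set_asset_info_preview(session, asset_info_id=info_id, preview_asset_id=preview_asset.id)
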
app/database/services/queries.py (new file)
@ -0,0 +1,59 @@
from typing import Optional, Sequence, Union

import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from ..models import Asset, AssetCacheState, AssetInfo


async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
    row = (
        await session.execute(
            select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
        )
    ).first()
    return row is not None


async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Optional[Asset]:
    return (
        await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
    ).scalars().first()


async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) -> Optional[AssetInfo]:
    return await session.get(AssetInfo, asset_info_id)


async def asset_info_exists_for_asset_id(session: AsyncSession, *, asset_id: str) -> bool:
    q = (
        select(sa.literal(True))
        .select_from(AssetInfo)
        .where(AssetInfo.asset_id == asset_id)
        .limit(1)
    )
    return (await session.execute(q)).first() is not None


async def get_cache_state_by_asset_id(session: AsyncSession, *, asset_id: str) -> Optional[AssetCacheState]:
    return (
        await session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
            .limit(1)
        )
    ).scalars().first()


async def list_cache_states_by_asset_id(
    session: AsyncSession, *, asset_id: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
    return (
        await session.execute(
            select(AssetCacheState)
            .where(AssetCacheState.asset_id == asset_id)
            .order_by(AssetCacheState.id.asc())
        )
    ).scalars().all()
@ -212,7 +212,6 @@ database_default_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite+aiosqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite+aiosqlite:///:memory:'.")
parser.add_argument("--enable-model-processing", action="store_true", help="Enable automatic processing of the model file, such as calculating hashes and populating the database.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")

if comfy.options.args_parsing:
main.py
@ -279,11 +279,11 @@ def cleanup_temp():
        shutil.rmtree(temp_dir, ignore_errors=True)

async def setup_database():
    from app import init_db_engine, start_background_assets_scan
    from app import init_db_engine, sync_seed_assets

    await init_db_engine()
    if not args.disable_assets_autoscan:
        await start_background_assets_scan()
    await sync_seed_assets(["models", "input", "output"])


def start_comfyui(asyncio_loop=None):

@ -37,7 +37,7 @@ from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from app.api.assets_routes import register_assets_system
from app import sync_seed_assets, register_assets_system
from protocol import BinaryEventTypes

async def send_socket_catch_exception(function, message):
@ -629,6 +629,7 @@ class PromptServer():

        @routes.get("/object_info")
        async def get_object_info(request):
            await sync_seed_assets(["models"])
            with folder_paths.cache_helper:
                out = {}
                for x in nodes.NODE_CLASS_MAPPINGS:

@ -118,6 +118,16 @@ async def test_head_asset_by_hash(http: aiohttp.ClientSession, api_base: str, se
        assert rh2.status == 404


@pytest.mark.asyncio
async def test_head_asset_bad_hash_returns_400_and_no_body(http: aiohttp.ClientSession, api_base: str):
    # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
    # aiohttp exposes an empty body for HEAD, so validate status and that there is no payload.
    async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh:
        assert rh.status == 400
        body = await rh.read()
        assert body == b""


@pytest.mark.asyncio
async def test_delete_nonexistent_returns_404(http: aiohttp.ClientSession, api_base: str):
    bogus = str(uuid.uuid4())
@ -166,12 +176,3 @@ async def test_update_requires_at_least_one_field(http: aiohttp.ClientSession, a
        body = await r.json()
        assert r.status == 400
        assert body["error"]["code"] == "INVALID_BODY"


@pytest.mark.asyncio
async def test_head_asset_bad_hash(http: aiohttp.ClientSession, api_base: str):
    # Invalid format
    async with http.head(f"{api_base}/api/assets/hash/not_a_hash") as rh3:
        jb = await rh3.json()
        assert rh3.status == 400
        assert jb is None  # HEAD request should not include "body" in response

@ -66,23 +66,32 @@ async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, se
    async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
        b1 = await r1.json()
        assert r1.status == 200, b1
        # normalized and deduplicated
        assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"]
        # normalized, deduplicated; 'unit-tests' was already present from the seed
        assert set(b1["added"]) == {"newtag", "beta"}
        assert set(b1["already_present"]) == {"unit-tests"}
        assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]

    async with http.get(f"{api_base}/api/assets/{aid}") as rg:
        g = await rg.json()
        assert rg.status == 200
        tags_now = set(g["tags"])
        assert "newtag" in tags_now
        assert "beta" in tags_now
        assert {"newtag", "beta"}.issubset(tags_now)

    # Remove a tag and a non-existent tag
    payload_del = {"tags": ["newtag", "does-not-exist"]}
    async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
        b2 = await r2.json()
        assert r2.status == 200
        assert "newtag" in b2["removed"]
        assert "does-not-exist" in b2["not_present"]
        assert set(b2["removed"]) == {"newtag"}
        assert set(b2["not_present"]) == {"does-not-exist"}

    # Verify remaining tags after deletion
    async with http.get(f"{api_base}/api/assets/{aid}") as rg2:
        g2 = await rg2.json()
        assert rg2.status == 200
        tags_later = set(g2["tags"])
        assert "newtag" not in tags_later
        assert "beta" in tags_later  # still present


@pytest.mark.asyncio

@ -206,7 +206,7 @@ async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_b
        body = await r.json()
        assert r.status == 400
        assert body["error"]["code"] == "INVALID_BODY"
        assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
        assert body["error"]["message"].startswith("unknown models category")


@pytest.mark.asyncio
