add "--multi-user" support

This commit is contained in:
bigcat88 2025-08-27 19:47:55 +03:00
parent 7c1b0be496
commit 026b7f209c
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
5 changed files with 136 additions and 56 deletions

View File

@ -2,8 +2,12 @@ import os
from pathlib import Path from pathlib import Path
from typing import Optional, Literal, Sequence from typing import Optional, Literal, Sequence
import sqlalchemy as sa
import folder_paths import folder_paths
from .database.models import AssetInfo
def get_comfy_models_folders() -> list[tuple[str, list[str]]]: def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations. """Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
@ -133,3 +137,11 @@ def ensure_within_base(candidate: str, base: str) -> None:
raise ValueError("destination escapes base directory") raise ValueError("destination escapes base directory")
except Exception: except Exception:
raise ValueError("invalid destination path") raise ValueError("invalid destination path")
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@ -8,11 +8,12 @@ from pydantic import ValidationError
import folder_paths import folder_paths
from .. import assets_manager, assets_scanner from .. import assets_manager, assets_scanner, user_manager
from . import schemas_in, schemas_out from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
UserManager: Optional[user_manager.UserManager] = None
@ROUTES.head("/api/assets/hash/{hash}") @ROUTES.head("/api/assets/hash/{hash}")
@ -45,6 +46,7 @@ async def list_assets(request: web.Request) -> web.Response:
offset=q.offset, offset=q.offset,
sort=q.sort, sort=q.sort,
order=q.order, order=q.order,
owner_id=UserManager.get_request_user_id(request),
) )
return web.json_response(payload.model_dump(mode="json")) return web.json_response(payload.model_dump(mode="json"))
@ -63,7 +65,8 @@ async def download_asset_content(request: web.Request) -> web.Response:
try: try:
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
asset_info_id=asset_info_id asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request),
) )
except ValueError as ve: except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve)) return _error_response(404, "ASSET_NOT_FOUND", str(ve))
@ -96,6 +99,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
name=body.name, name=body.name,
tags=body.tags, tags=body.tags,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=UserManager.get_request_user_id(request),
) )
if result is None: if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
@ -186,6 +190,7 @@ async def upload_asset(request: web.Request) -> web.Response:
spec, spec,
temp_path=tmp_path, temp_path=tmp_path,
client_filename=file_client_name, client_filename=file_client_name,
owner_id=UserManager.get_request_user_id(request),
) )
return web.json_response(created.model_dump(mode="json"), status=201) return web.json_response(created.model_dump(mode="json"), status=201)
except ValueError: except ValueError:
@ -207,7 +212,10 @@ async def get_asset(request: web.Request) -> web.Response:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try: try:
result = await assets_manager.get_asset(asset_info_id=asset_info_id) result = await assets_manager.get_asset(
asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request),
)
except ValueError as ve: except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception: except Exception:
@ -236,8 +244,9 @@ async def update_asset(request: web.Request) -> web.Response:
name=body.name, name=body.name,
tags=body.tags, tags=body.tags,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=UserManager.get_request_user_id(request),
) )
except ValueError as ve: except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception: except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.") return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -253,7 +262,10 @@ async def delete_asset(request: web.Request) -> web.Response:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try: try:
deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id) deleted = await assets_manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request),
)
except Exception: except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.") return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -280,6 +292,7 @@ async def get_tags(request: web.Request) -> web.Response:
offset=query.offset, offset=query.offset,
order=query.order, order=query.order,
include_zero=query.include_zero, include_zero=query.include_zero,
owner_id=UserManager.get_request_user_id(request),
) )
return web.json_response(result.model_dump(mode="json")) return web.json_response(result.model_dump(mode="json"))
@ -306,8 +319,9 @@ async def add_asset_tags(request: web.Request) -> web.Response:
tags=data.tags, tags=data.tags,
origin="manual", origin="manual",
added_by=None, added_by=None,
owner_id=UserManager.get_request_user_id(request),
) )
except ValueError as ve: except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception: except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.") return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -335,6 +349,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
result = await assets_manager.remove_tags_from_asset( result = await assets_manager.remove_tags_from_asset(
asset_info_id=asset_info_id, asset_info_id=asset_info_id,
tags=data.tags, tags=data.tags,
owner_id=UserManager.get_request_user_id(request),
) )
except ValueError as ve: except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
@ -370,7 +385,9 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
return web.json_response(states.model_dump(mode="json"), status=200) return web.json_response(states.model_dump(mode="json"), status=200)
def register_assets_routes(app: web.Application) -> None: def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global UserManager
UserManager = user_manager_instance
app.add_routes(ROUTES) app.add_routes(ROUTES)

View File

@ -25,9 +25,14 @@ from .database.services import (
get_asset_by_hash, get_asset_by_hash,
create_asset_info_for_existing_asset, create_asset_info_for_existing_asset,
fetch_asset_info_asset_and_tags, fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
) )
from .api import schemas_in, schemas_out from .api import schemas_in, schemas_out
from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags from ._assets_helpers import (
get_name_and_tags_from_asset_path,
ensure_within_base,
resolve_destination_from_tags,
)
from .assets_fetcher import ensure_asset_cached from .assets_fetcher import ensure_asset_cached
@ -98,6 +103,7 @@ async def list_assets(
offset: int = 0, offset: int = 0,
sort: str = "created_at", sort: str = "created_at",
order: str = "desc", order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList: ) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort) sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@ -105,6 +111,7 @@ async def list_assets(
async with await create_session() as session: async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page( infos, tag_map, total = await list_asset_infos_page(
session, session,
owner_id=owner_id,
include_tags=include_tags, include_tags=include_tags,
exclude_tags=exclude_tags, exclude_tags=exclude_tags,
name_contains=name_contains, name_contains=name_contains,
@ -141,9 +148,9 @@ async def list_assets(
) )
async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail: async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail:
async with await create_session() as session: async with await create_session() as session:
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id) res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res: if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found") raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res info, asset, tag_names = res
@ -163,7 +170,9 @@ async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
async def resolve_asset_content_for_download( async def resolve_asset_content_for_download(
*, asset_info_id: int *,
asset_info_id: int,
owner_id: str = "",
) -> tuple[str, str, str]: ) -> tuple[str, str, str]:
""" """
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
@ -173,7 +182,7 @@ async def resolve_asset_content_for_download(
ValueError if AssetInfo cannot be found ValueError if AssetInfo cannot be found
""" """
async with await create_session() as session: async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id) pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair: if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found") raise ValueError(f"AssetInfo {asset_info_id} not found")
@ -198,6 +207,7 @@ async def upload_asset_from_temp_path(
*, *,
temp_path: str, temp_path: str,
client_filename: Optional[str] = None, client_filename: Optional[str] = None,
owner_id: str = "",
) -> schemas_out.AssetCreated: ) -> schemas_out.AssetCreated:
""" """
Finalize an uploaded temp file: Finalize an uploaded temp file:
@ -250,7 +260,7 @@ async def upload_asset_from_temp_path(
mtime_ns=mtime_ns, mtime_ns=mtime_ns,
mime_type=content_type, mime_type=content_type,
info_name=os.path.basename(dest_abs), info_name=os.path.basename(dest_abs),
owner_id="", owner_id=owner_id,
preview_hash=None, preview_hash=None,
user_metadata=spec.user_metadata or {}, user_metadata=spec.user_metadata or {},
tags=spec.tags, tags=spec.tags,
@ -262,7 +272,7 @@ async def upload_asset_from_temp_path(
if not info_id: if not info_id:
raise RuntimeError("failed to create asset metadata") raise RuntimeError("failed to create asset metadata")
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id)) pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id)
if not pair: if not pair:
raise RuntimeError("inconsistent DB state after ingest") raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair info, asset = pair
@ -290,8 +300,15 @@ async def update_asset(
name: Optional[str] = None, name: Optional[str] = None,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None, user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated: ) -> schemas_out.AssetUpdated:
async with await create_session() as session: async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = await update_asset_info_full( info = await update_asset_info_full(
session, session,
asset_info_id=asset_info_id, asset_info_id=asset_info_id,
@ -300,6 +317,7 @@ async def update_asset(
user_metadata=user_metadata, user_metadata=user_metadata,
tag_origin="manual", tag_origin="manual",
added_by=None, added_by=None,
asset_info_row=info_row,
) )
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
@ -315,9 +333,9 @@ async def update_asset(
) )
async def delete_asset_reference(*, asset_info_id: int) -> bool: async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool:
async with await create_session() as session: async with await create_session() as session:
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id) r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
await session.commit() await session.commit()
return r return r
@ -328,6 +346,7 @@ async def create_asset_from_hash(
name: str, name: str,
tags: Optional[list[str]] = None, tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None, user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> Optional[schemas_out.AssetCreated]: ) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower() canonical = hash_str.strip().lower()
async with await create_session() as session: async with await create_session() as session:
@ -343,6 +362,7 @@ async def create_asset_from_hash(
tags=tags or [], tags=tags or [],
tag_origin="manual", tag_origin="manual",
added_by=None, added_by=None,
owner_id=owner_id,
) )
tag_names = await get_asset_tags(session, asset_info_id=info.id) tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit() await session.commit()
@ -369,6 +389,7 @@ async def list_tags(
offset: int = 0, offset: int = 0,
order: str = "count_desc", order: str = "count_desc",
include_zero: bool = True, include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList: ) -> schemas_out.TagsList:
limit = max(1, min(1000, limit)) limit = max(1, min(1000, limit))
offset = max(0, offset) offset = max(0, offset)
@ -381,6 +402,7 @@ async def list_tags(
offset=offset, offset=offset,
include_zero=include_zero, include_zero=include_zero,
order=order, order=order,
owner_id=owner_id,
) )
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
@ -393,8 +415,14 @@ async def add_tags_to_asset(
tags: list[str], tags: list[str],
origin: str = "manual", origin: str = "manual",
added_by: Optional[str] = None, added_by: Optional[str] = None,
owner_id: str = "",
) -> schemas_out.TagsAdd: ) -> schemas_out.TagsAdd:
async with await create_session() as session: async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await add_tags_to_asset_info( data = await add_tags_to_asset_info(
session, session,
asset_info_id=asset_info_id, asset_info_id=asset_info_id,
@ -402,6 +430,7 @@ async def add_tags_to_asset(
origin=origin, origin=origin,
added_by=added_by, added_by=added_by,
create_if_missing=True, create_if_missing=True,
asset_info_row=info_row,
) )
await session.commit() await session.commit()
return schemas_out.TagsAdd(**data) return schemas_out.TagsAdd(**data)
@ -411,8 +440,15 @@ async def remove_tags_from_asset(
*, *,
asset_info_id: int, asset_info_id: int,
tags: list[str], tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove: ) -> schemas_out.TagsRemove:
async with await create_session() as session: async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await remove_tags_from_asset_info( data = await remove_tags_from_asset_info(
session, session,
asset_info_id=asset_info_id, asset_info_id=asset_info_id,

View File

@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
from .timeutil import utcnow from .timeutil import utcnow
from .._assets_helpers import normalize_tags from .._assets_helpers import normalize_tags, visible_owner_clause
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option
return await session.get(Asset, asset_hash) return await session.get(Asset, asset_hash)
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def check_fs_asset_exists_quick( async def check_fs_asset_exists_quick(
session, session,
*, *,
@ -317,6 +321,7 @@ async def touch_asset_info_by_id(
async def list_asset_infos_page( async def list_asset_infos_page(
session: AsyncSession, session: AsyncSession,
*, *,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None, include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None, exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None, name_contains: Optional[str] = None,
@ -326,26 +331,18 @@ async def list_asset_infos_page(
sort: str = "created_at", sort: str = "created_at",
order: str = "desc", order: str = "desc",
) -> tuple[list[AssetInfo], dict[int, list[str]], int]: ) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
""" """Return page of AssetInfo rows in the viewers visibility."""
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
plus a map of asset_info_id -> [tags], and the total count.
We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Build base query
base = ( base = (
select(AssetInfo) select(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash) .join(Asset, Asset.hash == AssetInfo.asset_hash)
.options(contains_eager(AssetInfo.asset)) .options(contains_eager(AssetInfo.asset))
.where(visible_owner_clause(owner_id))
) )
# Filters
if name_contains: if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
base = _apply_tag_filters(base, include_tags, exclude_tags) base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter) base = _apply_metadata_filter(base, metadata_filter)
# Sort # Sort
@ -368,13 +365,14 @@ async def list_asset_infos_page(
select(func.count()) select(func.count())
.select_from(AssetInfo) .select_from(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash) .join(Asset, Asset.hash == AssetInfo.asset_hash)
.where(visible_owner_clause(owner_id))
) )
if name_contains: if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = (await session.execute(count_stmt)).scalar_one() total = int((await session.execute(count_stmt)).scalar_one() or 0)
# Fetch rows # Fetch rows
infos = (await session.execute(base)).scalars().unique().all() infos = (await session.execute(base)).scalars().unique().all()
@ -394,13 +392,22 @@ async def list_asset_infos_page(
return infos, tag_map, total return infos, tag_map, total
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]: async def fetch_asset_info_and_asset(
row = await session.execute( session: AsyncSession,
*,
asset_info_id: int,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset) select(AssetInfo, Asset)
.join(Asset, Asset.hash == AssetInfo.asset_hash) .join(Asset, Asset.hash == AssetInfo.asset_hash)
.where(AssetInfo.id == asset_info_id) .where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1) .limit(1)
) )
row = await session.execute(stmt)
pair = row.first() pair = row.first()
if not pair: if not pair:
return None return None
@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags(
session: AsyncSession, session: AsyncSession,
*, *,
asset_info_id: int, asset_info_id: int,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]: ) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
"""Fetch AssetInfo, its Asset, and all tag names.
Returns:
(AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist.
"""
stmt = ( stmt = (
select(AssetInfo, Asset, Tag.name) select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.hash == AssetInfo.asset_hash) .join(Asset, Asset.hash == AssetInfo.asset_hash)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(AssetInfo.id == asset_info_id) .where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags)) .options(noload(AssetInfo.tags))
.order_by(Tag.name.asc()) .order_by(Tag.name.asc())
) )
@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset(
tags: Optional[Sequence[str]] = None, tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual", tag_origin: str = "manual",
added_by: Optional[str] = None, added_by: Optional[str] = None,
owner_id: str = "",
) -> AssetInfo: ) -> AssetInfo:
"""Create a new AssetInfo referencing an existing Asset (no content write).""" """Create a new AssetInfo referencing an existing Asset (no content write)."""
now = utcnow() now = utcnow()
info = AssetInfo( info = AssetInfo(
owner_id="", owner_id=owner_id,
name=name, name=name,
asset_hash=asset_hash, asset_hash=asset_hash,
preview_hash=None, preview_hash=None,
@ -593,6 +600,7 @@ async def update_asset_info_full(
user_metadata: Optional[dict] = None, user_metadata: Optional[dict] = None,
tag_origin: str = "manual", tag_origin: str = "manual",
added_by: Optional[str] = None, added_by: Optional[str] = None,
asset_info_row: Any = None,
) -> AssetInfo: ) -> AssetInfo:
""" """
Update AssetInfo fields: Update AssetInfo fields:
@ -601,9 +609,12 @@ async def update_asset_info_full(
- replace tags with provided set (if provided) - replace tags with provided set (if provided)
Returns the updated AssetInfo. Returns the updated AssetInfo.
""" """
info = await session.get(AssetInfo, asset_info_id) if not asset_info_row:
if not info: info = await session.get(AssetInfo, asset_info_id)
raise ValueError(f"AssetInfo {asset_info_id} not found") if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False touched = False
if name is not None and name != info.name: if name is not None and name != info.name:
@ -633,9 +644,12 @@ async def update_asset_info_full(
return info return info
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool: async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool:
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata.""" """Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id)) res = await session.execute(delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
))
return bool(res.rowcount) return bool(res.rowcount)
@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
async def list_tags_with_usage( async def list_tags_with_usage(
session, session: AsyncSession,
*, *,
prefix: Optional[str] = None, prefix: Optional[str] = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
include_zero: bool = True, include_zero: bool = True,
order: str = "count_desc", # "count_desc" | "name_asc" order: str = "count_desc", # "count_desc" | "name_asc"
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]: ) -> tuple[list[tuple[str, str, int]], int]:
""" # Subquery with counts by tag_name and owner_id
Returns:
rows: list of (name, tag_type, count)
total: number of tags matching filter (independent of pagination)
"""
# Subquery with counts by tag_name
counts_sq = ( counts_sq = (
select( select(
AssetInfoTag.tag_name.label("tag_name"), AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"), func.count(AssetInfoTag.asset_info_id).label("cnt"),
) )
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name) .group_by(AssetInfoTag.tag_name)
.subquery() .subquery()
) )
@ -765,14 +778,16 @@ async def add_tags_to_asset_info(
origin: str = "manual", origin: str = "manual",
added_by: Optional[str] = None, added_by: Optional[str] = None,
create_if_missing: bool = True, create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict: ) -> dict:
"""Adds tags to an AssetInfo. """Adds tags to an AssetInfo.
If create_if_missing=True, missing tag rows are created as 'user'. If create_if_missing=True, missing tag rows are created as 'user'.
Returns: {"added": [...], "already_present": [...], "total_tags": [...]} Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
""" """
info = await session.get(AssetInfo, asset_info_id) if not asset_info_row:
if not info: info = await session.get(AssetInfo, asset_info_id)
raise ValueError(f"AssetInfo {asset_info_id} not found") if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags) norm = normalize_tags(tags)
if not norm: if not norm:

View File

@ -36,7 +36,7 @@ from app.user_manager import UserManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
from app.api.assets_routes import register_assets_routes from app.api.assets_routes import register_assets_system
from protocol import BinaryEventTypes from protocol import BinaryEventTypes
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
@ -182,7 +182,7 @@ class PromptServer():
else args.front_end_root else args.front_end_root
) )
logging.info(f"[Prompt Server] web root: {self.web_root}") logging.info(f"[Prompt Server] web root: {self.web_root}")
register_assets_routes(self.app) register_assets_system(self.app, self.user_manager)
routes = web.RouteTableDef() routes = web.RouteTableDef()
self.routes = routes self.routes = routes
self.last_node_id = None self.last_node_id = None