diff --git a/app/_assets_helpers.py b/app/_assets_helpers.py index 9fd3600f1..4a8d39625 100644 --- a/app/_assets_helpers.py +++ b/app/_assets_helpers.py @@ -2,8 +2,12 @@ import os from pathlib import Path from typing import Optional, Literal, Sequence +import sqlalchemy as sa + import folder_paths +from .database.models import AssetInfo + def get_comfy_models_folders() -> list[tuple[str, list[str]]]: """Build a list of (folder_name, base_paths[]) categories that are configured for model locations. @@ -133,3 +137,11 @@ def ensure_within_base(candidate: str, base: str) -> None: raise ValueError("destination escapes base directory") except Exception: raise ValueError("invalid destination path") + + +def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 2ca2932e5..b5c25dcec 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -8,11 +8,12 @@ from pydantic import ValidationError import folder_paths -from .. import assets_manager, assets_scanner +from .. import assets_manager, assets_scanner, user_manager from . import schemas_in, schemas_out ROUTES = web.RouteTableDef() +UserManager: Optional[user_manager.UserManager] = None @ROUTES.head("/api/assets/hash/{hash}") @@ -45,6 +46,7 @@ async def list_assets(request: web.Request) -> web.Response: offset=q.offset, sort=q.sort, order=q.order, + owner_id=UserManager.get_request_user_id(request), ) return web.json_response(payload.model_dump(mode="json")) @@ -63,7 +65,8 @@ async def download_asset_content(request: web.Request) -> web.Response: try: abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download( - asset_info_id=asset_info_id + asset_info_id=asset_info_id, + owner_id=UserManager.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve)) @@ -96,6 +99,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response: name=body.name, tags=body.tags, user_metadata=body.user_metadata, + owner_id=UserManager.get_request_user_id(request), ) if result is None: return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") @@ -186,6 +190,7 @@ async def upload_asset(request: web.Request) -> web.Response: spec, temp_path=tmp_path, client_filename=file_client_name, + owner_id=UserManager.get_request_user_id(request), ) return web.json_response(created.model_dump(mode="json"), status=201) except ValueError: @@ -207,7 +212,10 @@ async def get_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - result = await assets_manager.get_asset(asset_info_id=asset_info_id) + result = await assets_manager.get_asset( + asset_info_id=asset_info_id, + owner_id=UserManager.get_request_user_id(request), + ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: @@ -236,8 +244,9 @@ async def update_asset(request: web.Request) -> web.Response: name=body.name, tags=body.tags, user_metadata=body.user_metadata, + owner_id=UserManager.get_request_user_id(request), ) - except ValueError as ve: + except (ValueError, PermissionError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -253,7 +262,10 @@ async def delete_asset(request: web.Request) -> web.Response: return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.") try: - deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id) + deleted = await assets_manager.delete_asset_reference( + asset_info_id=asset_info_id, + owner_id=UserManager.get_request_user_id(request), + ) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -280,6 +292,7 @@ async def get_tags(request: web.Request) -> web.Response: offset=query.offset, order=query.order, include_zero=query.include_zero, + owner_id=UserManager.get_request_user_id(request), ) return web.json_response(result.model_dump(mode="json")) @@ -306,8 +319,9 @@ async def add_asset_tags(request: web.Request) -> web.Response: tags=data.tags, origin="manual", added_by=None, + owner_id=UserManager.get_request_user_id(request), ) - except ValueError as ve: + except (ValueError, PermissionError) as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) except Exception: return _error_response(500, "INTERNAL", "Unexpected server error.") @@ -335,6 +349,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response: result = await assets_manager.remove_tags_from_asset( asset_info_id=asset_info_id, tags=data.tags, + owner_id=UserManager.get_request_user_id(request), ) except ValueError as ve: return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) @@ -370,7 +385,9 @@ async def get_asset_scan_status(request: web.Request) -> web.Response: return web.json_response(states.model_dump(mode="json"), status=200) -def register_assets_routes(app: web.Application) -> None: +def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + global UserManager + UserManager = user_manager_instance app.add_routes(ROUTES) diff --git a/app/assets_manager.py b/app/assets_manager.py index e61895c8a..8cdf1fffc 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -25,9 +25,14 @@ from .database.services import ( get_asset_by_hash, create_asset_info_for_existing_asset, fetch_asset_info_asset_and_tags, + get_asset_info_by_id, ) from .api import schemas_in, schemas_out -from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags +from ._assets_helpers import ( + get_name_and_tags_from_asset_path, + ensure_within_base, + resolve_destination_from_tags, +) from .assets_fetcher import ensure_asset_cached @@ -98,6 +103,7 @@ async def list_assets( offset: int = 0, sort: str = "created_at", order: str = "desc", + owner_id: str = "", ) -> schemas_out.AssetsList: sort = _safe_sort_field(sort) order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower() @@ -105,6 +111,7 @@ async def list_assets( async with await create_session() as session: infos, tag_map, total = await list_asset_infos_page( session, + owner_id=owner_id, include_tags=include_tags, exclude_tags=exclude_tags, name_contains=name_contains, @@ -141,9 +148,9 @@ async def list_assets( ) -async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail: +async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail: async with await create_session() as session: - res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id) + res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) if not res: raise ValueError(f"AssetInfo {asset_info_id} not found") info, asset, tag_names = res @@ -163,7 +170,9 @@ async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail: async def resolve_asset_content_for_download( - *, asset_info_id: int + *, + asset_info_id: int, + owner_id: str = "", ) -> tuple[str, str, str]: """ Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time. @@ -173,7 +182,7 @@ async def resolve_asset_content_for_download( ValueError if AssetInfo cannot be found """ async with await create_session() as session: - pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id) + pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) if not pair: raise ValueError(f"AssetInfo {asset_info_id} not found") @@ -198,6 +207,7 @@ async def upload_asset_from_temp_path( *, temp_path: str, client_filename: Optional[str] = None, + owner_id: str = "", ) -> schemas_out.AssetCreated: """ Finalize an uploaded temp file: @@ -250,7 +260,7 @@ async def upload_asset_from_temp_path( mtime_ns=mtime_ns, mime_type=content_type, info_name=os.path.basename(dest_abs), - owner_id="", + owner_id=owner_id, preview_hash=None, user_metadata=spec.user_metadata or {}, tags=spec.tags, @@ -262,7 +272,7 @@ async def upload_asset_from_temp_path( if not info_id: raise RuntimeError("failed to create asset metadata") - pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id)) + pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id) if not pair: raise RuntimeError("inconsistent DB state after ingest") info, asset = pair @@ -290,8 +300,15 @@ async def update_asset( name: Optional[str] = None, tags: Optional[list[str]] = None, user_metadata: Optional[dict] = None, + owner_id: str = "", ) -> schemas_out.AssetUpdated: async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + info = await update_asset_info_full( session, asset_info_id=asset_info_id, @@ -300,6 +317,7 @@ async def update_asset( user_metadata=user_metadata, tag_origin="manual", added_by=None, + asset_info_row=info_row, ) tag_names = await get_asset_tags(session, asset_info_id=asset_info_id) @@ -315,9 +333,9 @@ async def update_asset( ) -async def delete_asset_reference(*, asset_info_id: int) -> bool: +async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool: async with await create_session() as session: - r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id) + r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) await session.commit() return r @@ -328,6 +346,7 @@ async def create_asset_from_hash( name: str, tags: Optional[list[str]] = None, user_metadata: Optional[dict] = None, + owner_id: str = "", ) -> Optional[schemas_out.AssetCreated]: canonical = hash_str.strip().lower() async with await create_session() as session: @@ -343,6 +362,7 @@ async def create_asset_from_hash( tags=tags or [], tag_origin="manual", added_by=None, + owner_id=owner_id, ) tag_names = await get_asset_tags(session, asset_info_id=info.id) await session.commit() @@ -369,6 +389,7 @@ async def list_tags( offset: int = 0, order: str = "count_desc", include_zero: bool = True, + owner_id: str = "", ) -> schemas_out.TagsList: limit = max(1, min(1000, limit)) offset = max(0, offset) @@ -381,6 +402,7 @@ async def list_tags( offset=offset, include_zero=include_zero, order=order, + owner_id=owner_id, ) tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] @@ -393,8 +415,14 @@ async def add_tags_to_asset( tags: list[str], origin: str = "manual", added_by: Optional[str] = None, + owner_id: str = "", ) -> schemas_out.TagsAdd: async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") data = await add_tags_to_asset_info( session, asset_info_id=asset_info_id, @@ -402,6 +430,7 @@ async def add_tags_to_asset( origin=origin, added_by=added_by, create_if_missing=True, + asset_info_row=info_row, ) await session.commit() return schemas_out.TagsAdd(**data) @@ -411,8 +440,15 @@ async def remove_tags_from_asset( *, asset_info_id: int, tags: list[str], + owner_id: str = "", ) -> schemas_out.TagsRemove: async with await create_session() as session: + info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + data = await remove_tags_from_asset_info( session, asset_info_id=asset_info_id, diff --git a/app/database/services.py b/app/database/services.py index 95a2d07ab..4bf09ed97 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation from .timeutil import utcnow -from .._assets_helpers import normalize_tags +from .._assets_helpers import normalize_tags, visible_owner_clause async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool: @@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option return await session.get(Asset, asset_hash) +async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]: + return await session.get(AssetInfo, asset_info_id) + + async def check_fs_asset_exists_quick( session, *, @@ -317,6 +321,7 @@ async def touch_asset_info_by_id( async def list_asset_infos_page( session: AsyncSession, *, + owner_id: str = "", include_tags: Optional[Sequence[str]] = None, exclude_tags: Optional[Sequence[str]] = None, name_contains: Optional[str] = None, @@ -326,26 +331,18 @@ async def list_asset_infos_page( sort: str = "created_at", order: str = "desc", ) -> tuple[list[AssetInfo], dict[int, list[str]], int]: - """ - Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1), - plus a map of asset_info_id -> [tags], and the total count. - - We purposely collect tags in a separate (single) query to avoid row explosion. - """ - - # Build base query + """Return page of AssetInfo rows in the viewers visibility.""" base = ( select(AssetInfo) .join(Asset, Asset.hash == AssetInfo.asset_hash) .options(contains_eager(AssetInfo.asset)) + .where(visible_owner_clause(owner_id)) ) - # Filters if name_contains: base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) base = _apply_tag_filters(base, include_tags, exclude_tags) - base = _apply_metadata_filter(base, metadata_filter) # Sort @@ -368,13 +365,14 @@ async def list_asset_infos_page( select(func.count()) .select_from(AssetInfo) .join(Asset, Asset.hash == AssetInfo.asset_hash) + .where(visible_owner_clause(owner_id)) ) if name_contains: count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) - total = (await session.execute(count_stmt)).scalar_one() + total = int((await session.execute(count_stmt)).scalar_one() or 0) # Fetch rows infos = (await session.execute(base)).scalars().unique().all() @@ -394,13 +392,22 @@ async def list_asset_infos_page( return infos, tag_map, total -async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]: - row = await session.execute( +async def fetch_asset_info_and_asset( + session: AsyncSession, + *, + asset_info_id: int, + owner_id: str = "", +) -> Optional[tuple[AssetInfo, Asset]]: + stmt = ( select(AssetInfo, Asset) .join(Asset, Asset.hash == AssetInfo.asset_hash) - .where(AssetInfo.id == asset_info_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) .limit(1) ) + row = await session.execute(stmt) pair = row.first() if not pair: return None @@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags( session: AsyncSession, *, asset_info_id: int, + owner_id: str = "", ) -> Optional[tuple[AssetInfo, Asset, list[str]]]: - """Fetch AssetInfo, its Asset, and all tag names. - - Returns: - (AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist. - """ stmt = ( select(AssetInfo, Asset, Tag.name) .join(Asset, Asset.hash == AssetInfo.asset_hash) .join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True) .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) - .where(AssetInfo.id == asset_info_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) .options(noload(AssetInfo.tags)) .order_by(Tag.name.asc()) ) @@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset( tags: Optional[Sequence[str]] = None, tag_origin: str = "manual", added_by: Optional[str] = None, + owner_id: str = "", ) -> AssetInfo: """Create a new AssetInfo referencing an existing Asset (no content write).""" now = utcnow() info = AssetInfo( - owner_id="", + owner_id=owner_id, name=name, asset_hash=asset_hash, preview_hash=None, @@ -593,6 +600,7 @@ async def update_asset_info_full( user_metadata: Optional[dict] = None, tag_origin: str = "manual", added_by: Optional[str] = None, + asset_info_row: Any = None, ) -> AssetInfo: """ Update AssetInfo fields: @@ -601,9 +609,12 @@ async def update_asset_info_full( - replace tags with provided set (if provided) Returns the updated AssetInfo. """ - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row touched = False if name is not None and name != info.name: @@ -633,9 +644,12 @@ async def update_asset_info_full( return info -async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool: +async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool: """Delete the user-visible AssetInfo row. Cascades clear tags and metadata.""" - res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id)) + res = await session.execute(delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + )) return bool(res.rowcount) @@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s async def list_tags_with_usage( - session, + session: AsyncSession, *, prefix: Optional[str] = None, limit: int = 100, offset: int = 0, include_zero: bool = True, - order: str = "count_desc", # "count_desc" | "name_asc" + order: str = "count_desc", # "count_desc" | "name_asc" + owner_id: str = "", ) -> tuple[list[tuple[str, str, int]], int]: - """ - Returns: - rows: list of (name, tag_type, count) - total: number of tags matching filter (independent of pagination) - """ - # Subquery with counts by tag_name + # Subquery with counts by tag_name and owner_id counts_sq = ( select( AssetInfoTag.tag_name.label("tag_name"), func.count(AssetInfoTag.asset_info_id).label("cnt"), ) + .select_from(AssetInfoTag) + .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) + .where(visible_owner_clause(owner_id)) .group_by(AssetInfoTag.tag_name) .subquery() ) @@ -765,14 +778,16 @@ async def add_tags_to_asset_info( origin: str = "manual", added_by: Optional[str] = None, create_if_missing: bool = True, + asset_info_row: Any = None, ) -> dict: """Adds tags to an AssetInfo. If create_if_missing=True, missing tag rows are created as 'user'. Returns: {"added": [...], "already_present": [...], "total_tags": [...]} """ - info = await session.get(AssetInfo, asset_info_id) - if not info: - raise ValueError(f"AssetInfo {asset_info_id} not found") + if not asset_info_row: + info = await session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") norm = normalize_tags(tags) if not norm: diff --git a/server.py b/server.py index ba368654f..310b7601c 100644 --- a/server.py +++ b/server.py @@ -36,7 +36,7 @@ from app.user_manager import UserManager from app.custom_node_manager import CustomNodeManager from typing import Optional, Union from api_server.routes.internal.internal_routes import InternalRoutes -from app.api.assets_routes import register_assets_routes +from app.api.assets_routes import register_assets_system from protocol import BinaryEventTypes async def send_socket_catch_exception(function, message): @@ -182,7 +182,7 @@ class PromptServer(): else args.front_end_root ) logging.info(f"[Prompt Server] web root: {self.web_root}") - register_assets_routes(self.app) + register_assets_system(self.app, self.user_manager) routes = web.RouteTableDef() self.routes = routes self.last_node_id = None