add "--multi-user" support

This commit is contained in:
bigcat88 2025-08-27 19:47:55 +03:00
parent 7c1b0be496
commit 026b7f209c
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
5 changed files with 136 additions and 56 deletions

View File

@ -2,8 +2,12 @@ import os
from pathlib import Path
from typing import Optional, Literal, Sequence
import sqlalchemy as sa
import folder_paths
from .database.models import AssetInfo
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
@ -133,3 +137,11 @@ def ensure_within_base(candidate: str, base: str) -> None:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])

View File

@ -8,11 +8,12 @@ from pydantic import ValidationError
import folder_paths
from .. import assets_manager, assets_scanner
from .. import assets_manager, assets_scanner, user_manager
from . import schemas_in, schemas_out
ROUTES = web.RouteTableDef()
UserManager: Optional[user_manager.UserManager] = None
@ROUTES.head("/api/assets/hash/{hash}")
@ -45,6 +46,7 @@ async def list_assets(request: web.Request) -> web.Response:
offset=q.offset,
sort=q.sort,
order=q.order,
owner_id=UserManager.get_request_user_id(request),
)
return web.json_response(payload.model_dump(mode="json"))
@ -63,7 +65,8 @@ async def download_asset_content(request: web.Request) -> web.Response:
try:
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
asset_info_id=asset_info_id
asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
@ -96,6 +99,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=UserManager.get_request_user_id(request),
)
if result is None:
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
@ -186,6 +190,7 @@ async def upload_asset(request: web.Request) -> web.Response:
spec,
temp_path=tmp_path,
client_filename=file_client_name,
owner_id=UserManager.get_request_user_id(request),
)
return web.json_response(created.model_dump(mode="json"), status=201)
except ValueError:
@ -207,7 +212,10 @@ async def get_asset(request: web.Request) -> web.Response:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try:
result = await assets_manager.get_asset(asset_info_id=asset_info_id)
result = await assets_manager.get_asset(
asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
@ -236,8 +244,9 @@ async def update_asset(request: web.Request) -> web.Response:
name=body.name,
tags=body.tags,
user_metadata=body.user_metadata,
owner_id=UserManager.get_request_user_id(request),
)
except ValueError as ve:
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -253,7 +262,10 @@ async def delete_asset(request: web.Request) -> web.Response:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try:
deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id)
deleted = await assets_manager.delete_asset_reference(
asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request),
)
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -280,6 +292,7 @@ async def get_tags(request: web.Request) -> web.Response:
offset=query.offset,
order=query.order,
include_zero=query.include_zero,
owner_id=UserManager.get_request_user_id(request),
)
return web.json_response(result.model_dump(mode="json"))
@ -306,8 +319,9 @@ async def add_asset_tags(request: web.Request) -> web.Response:
tags=data.tags,
origin="manual",
added_by=None,
owner_id=UserManager.get_request_user_id(request),
)
except ValueError as ve:
except (ValueError, PermissionError) as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -335,6 +349,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
result = await assets_manager.remove_tags_from_asset(
asset_info_id=asset_info_id,
tags=data.tags,
owner_id=UserManager.get_request_user_id(request),
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
@ -370,7 +385,9 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
return web.json_response(states.model_dump(mode="json"), status=200)
def register_assets_routes(app: web.Application) -> None:
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
global UserManager
UserManager = user_manager_instance
app.add_routes(ROUTES)

View File

@ -25,9 +25,14 @@ from .database.services import (
get_asset_by_hash,
create_asset_info_for_existing_asset,
fetch_asset_info_asset_and_tags,
get_asset_info_by_id,
)
from .api import schemas_in, schemas_out
from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
from ._assets_helpers import (
get_name_and_tags_from_asset_path,
ensure_within_base,
resolve_destination_from_tags,
)
from .assets_fetcher import ensure_asset_cached
@ -98,6 +103,7 @@ async def list_assets(
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
@ -105,6 +111,7 @@ async def list_assets(
async with await create_session() as session:
infos, tag_map, total = await list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
@ -141,9 +148,9 @@ async def list_assets(
)
async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail:
async with await create_session() as session:
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id)
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
@ -163,7 +170,9 @@ async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
async def resolve_asset_content_for_download(
*, asset_info_id: int
*,
asset_info_id: int,
owner_id: str = "",
) -> tuple[str, str, str]:
"""
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
@ -173,7 +182,7 @@ async def resolve_asset_content_for_download(
ValueError if AssetInfo cannot be found
"""
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
@ -198,6 +207,7 @@ async def upload_asset_from_temp_path(
*,
temp_path: str,
client_filename: Optional[str] = None,
owner_id: str = "",
) -> schemas_out.AssetCreated:
"""
Finalize an uploaded temp file:
@ -250,7 +260,7 @@ async def upload_asset_from_temp_path(
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=os.path.basename(dest_abs),
owner_id="",
owner_id=owner_id,
preview_hash=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
@ -262,7 +272,7 @@ async def upload_asset_from_temp_path(
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id))
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
@ -290,8 +300,15 @@ async def update_asset(
name: Optional[str] = None,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = await update_asset_info_full(
session,
asset_info_id=asset_info_id,
@ -300,6 +317,7 @@ async def update_asset(
user_metadata=user_metadata,
tag_origin="manual",
added_by=None,
asset_info_row=info_row,
)
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
@ -315,9 +333,9 @@ async def update_asset(
)
async def delete_asset_reference(*, asset_info_id: int) -> bool:
async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool:
async with await create_session() as session:
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id)
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
await session.commit()
return r
@ -328,6 +346,7 @@ async def create_asset_from_hash(
name: str,
tags: Optional[list[str]] = None,
user_metadata: Optional[dict] = None,
owner_id: str = "",
) -> Optional[schemas_out.AssetCreated]:
canonical = hash_str.strip().lower()
async with await create_session() as session:
@ -343,6 +362,7 @@ async def create_asset_from_hash(
tags=tags or [],
tag_origin="manual",
added_by=None,
owner_id=owner_id,
)
tag_names = await get_asset_tags(session, asset_info_id=info.id)
await session.commit()
@ -369,6 +389,7 @@ async def list_tags(
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
@ -381,6 +402,7 @@ async def list_tags(
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
@ -393,8 +415,14 @@ async def add_tags_to_asset(
tags: list[str],
origin: str = "manual",
added_by: Optional[str] = None,
owner_id: str = "",
) -> schemas_out.TagsAdd:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
@ -402,6 +430,7 @@ async def add_tags_to_asset(
origin=origin,
added_by=added_by,
create_if_missing=True,
asset_info_row=info_row,
)
await session.commit()
return schemas_out.TagsAdd(**data)
@ -411,8 +440,15 @@ async def remove_tags_from_asset(
*,
asset_info_id: int,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
async with await create_session() as session:
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = await remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,

View File

@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
from .timeutil import utcnow
from .._assets_helpers import normalize_tags
from .._assets_helpers import normalize_tags, visible_owner_clause
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option
return await session.get(Asset, asset_hash)
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]:
return await session.get(AssetInfo, asset_info_id)
async def check_fs_asset_exists_quick(
session,
*,
@ -317,6 +321,7 @@ async def touch_asset_info_by_id(
async def list_asset_infos_page(
session: AsyncSession,
*,
owner_id: str = "",
include_tags: Optional[Sequence[str]] = None,
exclude_tags: Optional[Sequence[str]] = None,
name_contains: Optional[str] = None,
@ -326,26 +331,18 @@ async def list_asset_infos_page(
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
"""
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
plus a map of asset_info_id -> [tags], and the total count.
We purposely collect tags in a separate (single) query to avoid row explosion.
"""
# Build base query
"""Return page of AssetInfo rows in the viewers visibility."""
base = (
select(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.options(contains_eager(AssetInfo.asset))
.where(visible_owner_clause(owner_id))
)
# Filters
if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
# Sort
@ -368,13 +365,14 @@ async def list_asset_infos_page(
select(func.count())
.select_from(AssetInfo)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.where(visible_owner_clause(owner_id))
)
if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
total = (await session.execute(count_stmt)).scalar_one()
total = int((await session.execute(count_stmt)).scalar_one() or 0)
# Fetch rows
infos = (await session.execute(base)).scalars().unique().all()
@ -394,13 +392,22 @@ async def list_asset_infos_page(
return infos, tag_map, total
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
row = await session.execute(
async def fetch_asset_info_and_asset(
session: AsyncSession,
*,
asset_info_id: int,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset]]:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.where(AssetInfo.id == asset_info_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
)
row = await session.execute(stmt)
pair = row.first()
if not pair:
return None
@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags(
session: AsyncSession,
*,
asset_info_id: int,
owner_id: str = "",
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
"""Fetch AssetInfo, its Asset, and all tag names.
Returns:
(AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist.
"""
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(AssetInfo.id == asset_info_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset(
tags: Optional[Sequence[str]] = None,
tag_origin: str = "manual",
added_by: Optional[str] = None,
owner_id: str = "",
) -> AssetInfo:
"""Create a new AssetInfo referencing an existing Asset (no content write)."""
now = utcnow()
info = AssetInfo(
owner_id="",
owner_id=owner_id,
name=name,
asset_hash=asset_hash,
preview_hash=None,
@ -593,6 +600,7 @@ async def update_asset_info_full(
user_metadata: Optional[dict] = None,
tag_origin: str = "manual",
added_by: Optional[str] = None,
asset_info_row: Any = None,
) -> AssetInfo:
"""
Update AssetInfo fields:
@ -601,9 +609,12 @@ async def update_asset_info_full(
- replace tags with provided set (if provided)
Returns the updated AssetInfo.
"""
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
@ -633,9 +644,12 @@ async def update_asset_info_full(
return info
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool:
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id))
res = await session.execute(delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
))
return bool(res.rowcount)
@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
async def list_tags_with_usage(
session,
session: AsyncSession,
*,
prefix: Optional[str] = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc", # "count_desc" | "name_asc"
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
"""
Returns:
rows: list of (name, tag_type, count)
total: number of tags matching filter (independent of pagination)
"""
# Subquery with counts by tag_name
# Subquery with counts by tag_name and owner_id
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
@ -765,11 +778,13 @@ async def add_tags_to_asset_info(
origin: str = "manual",
added_by: Optional[str] = None,
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
"""Adds tags to an AssetInfo.
If create_if_missing=True, missing tag rows are created as 'user'.
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
"""
if not asset_info_row:
info = await session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")

View File

@ -36,7 +36,7 @@ from app.user_manager import UserManager
from app.custom_node_manager import CustomNodeManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from app.api.assets_routes import register_assets_routes
from app.api.assets_routes import register_assets_system
from protocol import BinaryEventTypes
async def send_socket_catch_exception(function, message):
@ -182,7 +182,7 @@ class PromptServer():
else args.front_end_root
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
register_assets_routes(self.app)
register_assets_system(self.app, self.user_manager)
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None