mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-14 16:34:36 +08:00
add "--multi-user" support
This commit is contained in:
parent
7c1b0be496
commit
026b7f209c
@ -2,8 +2,12 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .database.models import AssetInfo
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
@ -133,3 +137,11 @@ def ensure_within_base(candidate: str, base: str) -> None:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
@ -8,11 +8,12 @@ from pydantic import ValidationError
|
||||
|
||||
import folder_paths
|
||||
|
||||
from .. import assets_manager, assets_scanner
|
||||
from .. import assets_manager, assets_scanner, user_manager
|
||||
from . import schemas_in, schemas_out
|
||||
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
UserManager: Optional[user_manager.UserManager] = None
|
||||
|
||||
|
||||
@ROUTES.head("/api/assets/hash/{hash}")
|
||||
@ -45,6 +46,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
offset=q.offset,
|
||||
sort=q.sort,
|
||||
order=q.order,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
@ -63,7 +65,8 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
|
||||
asset_info_id=asset_info_id
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
@ -96,6 +99,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
if result is None:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||
@ -186,6 +190,7 @@ async def upload_asset(request: web.Request) -> web.Response:
|
||||
spec,
|
||||
temp_path=tmp_path,
|
||||
client_filename=file_client_name,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(created.model_dump(mode="json"), status=201)
|
||||
except ValueError:
|
||||
@ -207,7 +212,10 @@ async def get_asset(request: web.Request) -> web.Response:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
try:
|
||||
result = await assets_manager.get_asset(asset_info_id=asset_info_id)
|
||||
result = await assets_manager.get_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
@ -236,8 +244,9 @@ async def update_asset(request: web.Request) -> web.Response:
|
||||
name=body.name,
|
||||
tags=body.tags,
|
||||
user_metadata=body.user_metadata,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
@ -253,7 +262,10 @@ async def delete_asset(request: web.Request) -> web.Response:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
try:
|
||||
deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id)
|
||||
deleted = await assets_manager.delete_asset_reference(
|
||||
asset_info_id=asset_info_id,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
@ -280,6 +292,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
offset=query.offset,
|
||||
order=query.order,
|
||||
include_zero=query.include_zero,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
return web.json_response(result.model_dump(mode="json"))
|
||||
|
||||
@ -306,8 +319,9 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
||||
tags=data.tags,
|
||||
origin="manual",
|
||||
added_by=None,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
except (ValueError, PermissionError) as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
@ -335,6 +349,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||
result = await assets_manager.remove_tags_from_asset(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=data.tags,
|
||||
owner_id=UserManager.get_request_user_id(request),
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||
@ -370,7 +385,9 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
|
||||
return web.json_response(states.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
def register_assets_routes(app: web.Application) -> None:
|
||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||
global UserManager
|
||||
UserManager = user_manager_instance
|
||||
app.add_routes(ROUTES)
|
||||
|
||||
|
||||
|
||||
@ -25,9 +25,14 @@ from .database.services import (
|
||||
get_asset_by_hash,
|
||||
create_asset_info_for_existing_asset,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
get_asset_info_by_id,
|
||||
)
|
||||
from .api import schemas_in, schemas_out
|
||||
from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
|
||||
from ._assets_helpers import (
|
||||
get_name_and_tags_from_asset_path,
|
||||
ensure_within_base,
|
||||
resolve_destination_from_tags,
|
||||
)
|
||||
from .assets_fetcher import ensure_asset_cached
|
||||
|
||||
|
||||
@ -98,6 +103,7 @@ async def list_assets(
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
@ -105,6 +111,7 @@ async def list_assets(
|
||||
async with await create_session() as session:
|
||||
infos, tag_map, total = await list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
@ -141,9 +148,9 @@ async def list_assets(
|
||||
)
|
||||
|
||||
|
||||
async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
|
||||
async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||
async with await create_session() as session:
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id)
|
||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
@ -163,7 +170,9 @@ async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
|
||||
|
||||
|
||||
async def resolve_asset_content_for_download(
|
||||
*, asset_info_id: int
|
||||
*,
|
||||
asset_info_id: int,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
|
||||
@ -173,7 +182,7 @@ async def resolve_asset_content_for_download(
|
||||
ValueError if AssetInfo cannot be found
|
||||
"""
|
||||
async with await create_session() as session:
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
@ -198,6 +207,7 @@ async def upload_asset_from_temp_path(
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Finalize an uploaded temp file:
|
||||
@ -250,7 +260,7 @@ async def upload_asset_from_temp_path(
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=os.path.basename(dest_abs),
|
||||
owner_id="",
|
||||
owner_id=owner_id,
|
||||
preview_hash=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
@ -262,7 +272,7 @@ async def upload_asset_from_temp_path(
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id))
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
@ -290,8 +300,15 @@ async def update_asset(
|
||||
name: Optional[str] = None,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = await update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
@ -300,6 +317,7 @@ async def update_asset(
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
added_by=None,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
@ -315,9 +333,9 @@ async def update_asset(
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: int) -> bool:
|
||||
async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool:
|
||||
async with await create_session() as session:
|
||||
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
await session.commit()
|
||||
return r
|
||||
|
||||
@ -328,6 +346,7 @@ async def create_asset_from_hash(
|
||||
name: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
user_metadata: Optional[dict] = None,
|
||||
owner_id: str = "",
|
||||
) -> Optional[schemas_out.AssetCreated]:
|
||||
canonical = hash_str.strip().lower()
|
||||
async with await create_session() as session:
|
||||
@ -343,6 +362,7 @@ async def create_asset_from_hash(
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
added_by=None,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||
await session.commit()
|
||||
@ -369,6 +389,7 @@ async def list_tags(
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
@ -381,6 +402,7 @@ async def list_tags(
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
@ -393,8 +415,14 @@ async def add_tags_to_asset(
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = await add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
@ -402,6 +430,7 @@ async def add_tags_to_asset(
|
||||
origin=origin,
|
||||
added_by=added_by,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
await session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
@ -411,8 +440,15 @@ async def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: int,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
async with await create_session() as session:
|
||||
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = await remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
|
||||
@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
|
||||
from .timeutil import utcnow
|
||||
from .._assets_helpers import normalize_tags
|
||||
from .._assets_helpers import normalize_tags, visible_owner_clause
|
||||
|
||||
|
||||
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
||||
@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option
|
||||
return await session.get(Asset, asset_hash)
|
||||
|
||||
|
||||
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]:
|
||||
return await session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
session,
|
||||
*,
|
||||
@ -317,6 +321,7 @@ async def touch_asset_info_by_id(
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
owner_id: str = "",
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
@ -326,26 +331,18 @@ async def list_asset_infos_page(
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||
"""
|
||||
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
||||
plus a map of asset_info_id -> [tags], and the total count.
|
||||
|
||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
||||
"""
|
||||
|
||||
# Build base query
|
||||
"""Return page of AssetInfo rows in the viewers visibility."""
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.options(contains_eager(AssetInfo.asset))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
# Filters
|
||||
if name_contains:
|
||||
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
|
||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||
|
||||
base = _apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
# Sort
|
||||
@ -368,13 +365,14 @@ async def list_asset_infos_page(
|
||||
select(func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = (await session.execute(count_stmt)).scalar_one()
|
||||
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
# Fetch rows
|
||||
infos = (await session.execute(base)).scalars().unique().all()
|
||||
@ -394,13 +392,22 @@ async def list_asset_infos_page(
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
row = await session.execute(
|
||||
async def fetch_asset_info_and_asset(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.where(AssetInfo.id == asset_info_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
row = await session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
owner_id: str = "",
|
||||
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||
"""Fetch AssetInfo, its Asset, and all tag names.
|
||||
|
||||
Returns:
|
||||
(AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist.
|
||||
"""
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(AssetInfo.id == asset_info_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset(
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create a new AssetInfo referencing an existing Asset (no content write)."""
|
||||
now = utcnow()
|
||||
info = AssetInfo(
|
||||
owner_id="",
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_hash=asset_hash,
|
||||
preview_hash=None,
|
||||
@ -593,6 +600,7 @@ async def update_asset_info_full(
|
||||
user_metadata: Optional[dict] = None,
|
||||
tag_origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
"""
|
||||
Update AssetInfo fields:
|
||||
@ -601,9 +609,12 @@ async def update_asset_info_full(
|
||||
- replace tags with provided set (if provided)
|
||||
Returns the updated AssetInfo.
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
@ -633,9 +644,12 @@ async def update_asset_info_full(
|
||||
return info
|
||||
|
||||
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool:
|
||||
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
|
||||
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id))
|
||||
res = await session.execute(delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
))
|
||||
return bool(res.rowcount)
|
||||
|
||||
|
||||
@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
|
||||
|
||||
|
||||
async def list_tags_with_usage(
|
||||
session,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
"""
|
||||
Returns:
|
||||
rows: list of (name, tag_type, count)
|
||||
total: number of tags matching filter (independent of pagination)
|
||||
"""
|
||||
# Subquery with counts by tag_name
|
||||
# Subquery with counts by tag_name and owner_id
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
@ -765,14 +778,16 @@ async def add_tags_to_asset_info(
|
||||
origin: str = "manual",
|
||||
added_by: Optional[str] = None,
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
"""Adds tags to an AssetInfo.
|
||||
If create_if_missing=True, missing tag rows are created as 'user'.
|
||||
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
|
||||
"""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if not asset_info_row:
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
|
||||
@ -36,7 +36,7 @@ from app.user_manager import UserManager
|
||||
from app.custom_node_manager import CustomNodeManager
|
||||
from typing import Optional, Union
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from app.api.assets_routes import register_assets_routes
|
||||
from app.api.assets_routes import register_assets_system
|
||||
from protocol import BinaryEventTypes
|
||||
|
||||
async def send_socket_catch_exception(function, message):
|
||||
@ -182,7 +182,7 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_routes(self.app)
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user