mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-15 00:44:25 +08:00
add "--multi-user" support
This commit is contained in:
parent
7c1b0be496
commit
026b7f209c
@ -2,8 +2,12 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Literal, Sequence
|
from typing import Optional, Literal, Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
|
from .database.models import AssetInfo
|
||||||
|
|
||||||
|
|
||||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||||
@ -133,3 +137,11 @@ def ensure_within_base(candidate: str, base: str) -> None:
|
|||||||
raise ValueError("destination escapes base directory")
|
raise ValueError("destination escapes base directory")
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("invalid destination path")
|
raise ValueError("invalid destination path")
|
||||||
|
|
||||||
|
|
||||||
|
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||||
|
"""Build owner visibility predicate for reads."""
|
||||||
|
owner_id = (owner_id or "").strip()
|
||||||
|
if owner_id == "":
|
||||||
|
return AssetInfo.owner_id == ""
|
||||||
|
return AssetInfo.owner_id.in_(["", owner_id])
|
||||||
|
|||||||
@ -8,11 +8,12 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
|
||||||
from .. import assets_manager, assets_scanner
|
from .. import assets_manager, assets_scanner, user_manager
|
||||||
from . import schemas_in, schemas_out
|
from . import schemas_in, schemas_out
|
||||||
|
|
||||||
|
|
||||||
ROUTES = web.RouteTableDef()
|
ROUTES = web.RouteTableDef()
|
||||||
|
UserManager: Optional[user_manager.UserManager] = None
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.head("/api/assets/hash/{hash}")
|
@ROUTES.head("/api/assets/hash/{hash}")
|
||||||
@ -45,6 +46,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
|||||||
offset=q.offset,
|
offset=q.offset,
|
||||||
sort=q.sort,
|
sort=q.sort,
|
||||||
order=q.order,
|
order=q.order,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(payload.model_dump(mode="json"))
|
return web.json_response(payload.model_dump(mode="json"))
|
||||||
|
|
||||||
@ -63,7 +65,8 @@ async def download_asset_content(request: web.Request) -> web.Response:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
|
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
|
||||||
asset_info_id=asset_info_id
|
asset_info_id=asset_info_id,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||||
@ -96,6 +99,7 @@ async def create_asset_from_hash(request: web.Request) -> web.Response:
|
|||||||
name=body.name,
|
name=body.name,
|
||||||
tags=body.tags,
|
tags=body.tags,
|
||||||
user_metadata=body.user_metadata,
|
user_metadata=body.user_metadata,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||||
@ -186,6 +190,7 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
spec,
|
spec,
|
||||||
temp_path=tmp_path,
|
temp_path=tmp_path,
|
||||||
client_filename=file_client_name,
|
client_filename=file_client_name,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(created.model_dump(mode="json"), status=201)
|
return web.json_response(created.model_dump(mode="json"), status=201)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@ -207,7 +212,10 @@ async def get_asset(request: web.Request) -> web.Response:
|
|||||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await assets_manager.get_asset(asset_info_id=asset_info_id)
|
result = await assets_manager.get_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
|
)
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
except Exception:
|
except Exception:
|
||||||
@ -236,8 +244,9 @@ async def update_asset(request: web.Request) -> web.Response:
|
|||||||
name=body.name,
|
name=body.name,
|
||||||
tags=body.tags,
|
tags=body.tags,
|
||||||
user_metadata=body.user_metadata,
|
user_metadata=body.user_metadata,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
except ValueError as ve:
|
except (ValueError, PermissionError) as ve:
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
except Exception:
|
except Exception:
|
||||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
@ -253,7 +262,10 @@ async def delete_asset(request: web.Request) -> web.Response:
|
|||||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id)
|
deleted = await assets_manager.delete_asset_reference(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
@ -280,6 +292,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
|||||||
offset=query.offset,
|
offset=query.offset,
|
||||||
order=query.order,
|
order=query.order,
|
||||||
include_zero=query.include_zero,
|
include_zero=query.include_zero,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(result.model_dump(mode="json"))
|
return web.json_response(result.model_dump(mode="json"))
|
||||||
|
|
||||||
@ -306,8 +319,9 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
|||||||
tags=data.tags,
|
tags=data.tags,
|
||||||
origin="manual",
|
origin="manual",
|
||||||
added_by=None,
|
added_by=None,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
except ValueError as ve:
|
except (ValueError, PermissionError) as ve:
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
except Exception:
|
except Exception:
|
||||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
@ -335,6 +349,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
|||||||
result = await assets_manager.remove_tags_from_asset(
|
result = await assets_manager.remove_tags_from_asset(
|
||||||
asset_info_id=asset_info_id,
|
asset_info_id=asset_info_id,
|
||||||
tags=data.tags,
|
tags=data.tags,
|
||||||
|
owner_id=UserManager.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
@ -370,7 +385,9 @@ async def get_asset_scan_status(request: web.Request) -> web.Response:
|
|||||||
return web.json_response(states.model_dump(mode="json"), status=200)
|
return web.json_response(states.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
def register_assets_routes(app: web.Application) -> None:
|
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||||
|
global UserManager
|
||||||
|
UserManager = user_manager_instance
|
||||||
app.add_routes(ROUTES)
|
app.add_routes(ROUTES)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -25,9 +25,14 @@ from .database.services import (
|
|||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
create_asset_info_for_existing_asset,
|
create_asset_info_for_existing_asset,
|
||||||
fetch_asset_info_asset_and_tags,
|
fetch_asset_info_asset_and_tags,
|
||||||
|
get_asset_info_by_id,
|
||||||
)
|
)
|
||||||
from .api import schemas_in, schemas_out
|
from .api import schemas_in, schemas_out
|
||||||
from ._assets_helpers import get_name_and_tags_from_asset_path, ensure_within_base, resolve_destination_from_tags
|
from ._assets_helpers import (
|
||||||
|
get_name_and_tags_from_asset_path,
|
||||||
|
ensure_within_base,
|
||||||
|
resolve_destination_from_tags,
|
||||||
|
)
|
||||||
from .assets_fetcher import ensure_asset_cached
|
from .assets_fetcher import ensure_asset_cached
|
||||||
|
|
||||||
|
|
||||||
@ -98,6 +103,7 @@ async def list_assets(
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
sort: str = "created_at",
|
sort: str = "created_at",
|
||||||
order: str = "desc",
|
order: str = "desc",
|
||||||
|
owner_id: str = "",
|
||||||
) -> schemas_out.AssetsList:
|
) -> schemas_out.AssetsList:
|
||||||
sort = _safe_sort_field(sort)
|
sort = _safe_sort_field(sort)
|
||||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||||
@ -105,6 +111,7 @@ async def list_assets(
|
|||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
infos, tag_map, total = await list_asset_infos_page(
|
infos, tag_map, total = await list_asset_infos_page(
|
||||||
session,
|
session,
|
||||||
|
owner_id=owner_id,
|
||||||
include_tags=include_tags,
|
include_tags=include_tags,
|
||||||
exclude_tags=exclude_tags,
|
exclude_tags=exclude_tags,
|
||||||
name_contains=name_contains,
|
name_contains=name_contains,
|
||||||
@ -141,9 +148,9 @@ async def list_assets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
|
async def get_asset(*, asset_info_id: int, owner_id: str = "") -> schemas_out.AssetDetail:
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id)
|
res = await fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
if not res:
|
if not res:
|
||||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
info, asset, tag_names = res
|
info, asset, tag_names = res
|
||||||
@ -163,7 +170,9 @@ async def get_asset(*, asset_info_id: int) -> schemas_out.AssetDetail:
|
|||||||
|
|
||||||
|
|
||||||
async def resolve_asset_content_for_download(
|
async def resolve_asset_content_for_download(
|
||||||
*, asset_info_id: int
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
owner_id: str = "",
|
||||||
) -> tuple[str, str, str]:
|
) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
|
Returns (abs_path, content_type, download_name) for the given AssetInfo id and touches last_access_time.
|
||||||
@ -173,7 +182,7 @@ async def resolve_asset_content_for_download(
|
|||||||
ValueError if AssetInfo cannot be found
|
ValueError if AssetInfo cannot be found
|
||||||
"""
|
"""
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
|
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
if not pair:
|
if not pair:
|
||||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
@ -198,6 +207,7 @@ async def upload_asset_from_temp_path(
|
|||||||
*,
|
*,
|
||||||
temp_path: str,
|
temp_path: str,
|
||||||
client_filename: Optional[str] = None,
|
client_filename: Optional[str] = None,
|
||||||
|
owner_id: str = "",
|
||||||
) -> schemas_out.AssetCreated:
|
) -> schemas_out.AssetCreated:
|
||||||
"""
|
"""
|
||||||
Finalize an uploaded temp file:
|
Finalize an uploaded temp file:
|
||||||
@ -250,7 +260,7 @@ async def upload_asset_from_temp_path(
|
|||||||
mtime_ns=mtime_ns,
|
mtime_ns=mtime_ns,
|
||||||
mime_type=content_type,
|
mime_type=content_type,
|
||||||
info_name=os.path.basename(dest_abs),
|
info_name=os.path.basename(dest_abs),
|
||||||
owner_id="",
|
owner_id=owner_id,
|
||||||
preview_hash=None,
|
preview_hash=None,
|
||||||
user_metadata=spec.user_metadata or {},
|
user_metadata=spec.user_metadata or {},
|
||||||
tags=spec.tags,
|
tags=spec.tags,
|
||||||
@ -262,7 +272,7 @@ async def upload_asset_from_temp_path(
|
|||||||
if not info_id:
|
if not info_id:
|
||||||
raise RuntimeError("failed to create asset metadata")
|
raise RuntimeError("failed to create asset metadata")
|
||||||
|
|
||||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id))
|
pair = await fetch_asset_info_and_asset(session, asset_info_id=int(info_id), owner_id=owner_id)
|
||||||
if not pair:
|
if not pair:
|
||||||
raise RuntimeError("inconsistent DB state after ingest")
|
raise RuntimeError("inconsistent DB state after ingest")
|
||||||
info, asset = pair
|
info, asset = pair
|
||||||
@ -290,8 +300,15 @@ async def update_asset(
|
|||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
user_metadata: Optional[dict] = None,
|
user_metadata: Optional[dict] = None,
|
||||||
|
owner_id: str = "",
|
||||||
) -> schemas_out.AssetUpdated:
|
) -> schemas_out.AssetUpdated:
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
|
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
info = await update_asset_info_full(
|
info = await update_asset_info_full(
|
||||||
session,
|
session,
|
||||||
asset_info_id=asset_info_id,
|
asset_info_id=asset_info_id,
|
||||||
@ -300,6 +317,7 @@ async def update_asset(
|
|||||||
user_metadata=user_metadata,
|
user_metadata=user_metadata,
|
||||||
tag_origin="manual",
|
tag_origin="manual",
|
||||||
added_by=None,
|
added_by=None,
|
||||||
|
asset_info_row=info_row,
|
||||||
)
|
)
|
||||||
|
|
||||||
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
tag_names = await get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
@ -315,9 +333,9 @@ async def update_asset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def delete_asset_reference(*, asset_info_id: int) -> bool:
|
async def delete_asset_reference(*, asset_info_id: int, owner_id: str) -> bool:
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id)
|
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return r
|
return r
|
||||||
|
|
||||||
@ -328,6 +346,7 @@ async def create_asset_from_hash(
|
|||||||
name: str,
|
name: str,
|
||||||
tags: Optional[list[str]] = None,
|
tags: Optional[list[str]] = None,
|
||||||
user_metadata: Optional[dict] = None,
|
user_metadata: Optional[dict] = None,
|
||||||
|
owner_id: str = "",
|
||||||
) -> Optional[schemas_out.AssetCreated]:
|
) -> Optional[schemas_out.AssetCreated]:
|
||||||
canonical = hash_str.strip().lower()
|
canonical = hash_str.strip().lower()
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
@ -343,6 +362,7 @@ async def create_asset_from_hash(
|
|||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
tag_origin="manual",
|
tag_origin="manual",
|
||||||
added_by=None,
|
added_by=None,
|
||||||
|
owner_id=owner_id,
|
||||||
)
|
)
|
||||||
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
tag_names = await get_asset_tags(session, asset_info_id=info.id)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@ -369,6 +389,7 @@ async def list_tags(
|
|||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
order: str = "count_desc",
|
order: str = "count_desc",
|
||||||
include_zero: bool = True,
|
include_zero: bool = True,
|
||||||
|
owner_id: str = "",
|
||||||
) -> schemas_out.TagsList:
|
) -> schemas_out.TagsList:
|
||||||
limit = max(1, min(1000, limit))
|
limit = max(1, min(1000, limit))
|
||||||
offset = max(0, offset)
|
offset = max(0, offset)
|
||||||
@ -381,6 +402,7 @@ async def list_tags(
|
|||||||
offset=offset,
|
offset=offset,
|
||||||
include_zero=include_zero,
|
include_zero=include_zero,
|
||||||
order=order,
|
order=order,
|
||||||
|
owner_id=owner_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||||
@ -393,8 +415,14 @@ async def add_tags_to_asset(
|
|||||||
tags: list[str],
|
tags: list[str],
|
||||||
origin: str = "manual",
|
origin: str = "manual",
|
||||||
added_by: Optional[str] = None,
|
added_by: Optional[str] = None,
|
||||||
|
owner_id: str = "",
|
||||||
) -> schemas_out.TagsAdd:
|
) -> schemas_out.TagsAdd:
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
|
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
data = await add_tags_to_asset_info(
|
data = await add_tags_to_asset_info(
|
||||||
session,
|
session,
|
||||||
asset_info_id=asset_info_id,
|
asset_info_id=asset_info_id,
|
||||||
@ -402,6 +430,7 @@ async def add_tags_to_asset(
|
|||||||
origin=origin,
|
origin=origin,
|
||||||
added_by=added_by,
|
added_by=added_by,
|
||||||
create_if_missing=True,
|
create_if_missing=True,
|
||||||
|
asset_info_row=info_row,
|
||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return schemas_out.TagsAdd(**data)
|
return schemas_out.TagsAdd(**data)
|
||||||
@ -411,8 +440,15 @@ async def remove_tags_from_asset(
|
|||||||
*,
|
*,
|
||||||
asset_info_id: int,
|
asset_info_id: int,
|
||||||
tags: list[str],
|
tags: list[str],
|
||||||
|
owner_id: str = "",
|
||||||
) -> schemas_out.TagsRemove:
|
) -> schemas_out.TagsRemove:
|
||||||
async with await create_session() as session:
|
async with await create_session() as session:
|
||||||
|
info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
data = await remove_tags_from_asset_info(
|
data = await remove_tags_from_asset_info(
|
||||||
session,
|
session,
|
||||||
asset_info_id=asset_info_id,
|
asset_info_id=asset_info_id,
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
|
|
||||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
|
from .models import Asset, AssetInfo, AssetInfoTag, AssetCacheState, Tag, AssetInfoMeta, AssetLocation
|
||||||
from .timeutil import utcnow
|
from .timeutil import utcnow
|
||||||
from .._assets_helpers import normalize_tags
|
from .._assets_helpers import normalize_tags, visible_owner_clause
|
||||||
|
|
||||||
|
|
||||||
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
async def asset_exists_by_hash(session: AsyncSession, *, asset_hash: str) -> bool:
|
||||||
@ -30,6 +30,10 @@ async def get_asset_by_hash(session: AsyncSession, *, asset_hash: str) -> Option
|
|||||||
return await session.get(Asset, asset_hash)
|
return await session.get(Asset, asset_hash)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> Optional[AssetInfo]:
|
||||||
|
return await session.get(AssetInfo, asset_info_id)
|
||||||
|
|
||||||
|
|
||||||
async def check_fs_asset_exists_quick(
|
async def check_fs_asset_exists_quick(
|
||||||
session,
|
session,
|
||||||
*,
|
*,
|
||||||
@ -317,6 +321,7 @@ async def touch_asset_info_by_id(
|
|||||||
async def list_asset_infos_page(
|
async def list_asset_infos_page(
|
||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
|
owner_id: str = "",
|
||||||
include_tags: Optional[Sequence[str]] = None,
|
include_tags: Optional[Sequence[str]] = None,
|
||||||
exclude_tags: Optional[Sequence[str]] = None,
|
exclude_tags: Optional[Sequence[str]] = None,
|
||||||
name_contains: Optional[str] = None,
|
name_contains: Optional[str] = None,
|
||||||
@ -326,26 +331,18 @@ async def list_asset_infos_page(
|
|||||||
sort: str = "created_at",
|
sort: str = "created_at",
|
||||||
order: str = "desc",
|
order: str = "desc",
|
||||||
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
) -> tuple[list[AssetInfo], dict[int, list[str]], int]:
|
||||||
"""
|
"""Return page of AssetInfo rows in the viewers visibility."""
|
||||||
Returns a page of AssetInfo rows with their Asset eagerly loaded (no N+1),
|
|
||||||
plus a map of asset_info_id -> [tags], and the total count.
|
|
||||||
|
|
||||||
We purposely collect tags in a separate (single) query to avoid row explosion.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Build base query
|
|
||||||
base = (
|
base = (
|
||||||
select(AssetInfo)
|
select(AssetInfo)
|
||||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||||
.options(contains_eager(AssetInfo.asset))
|
.options(contains_eager(AssetInfo.asset))
|
||||||
|
.where(visible_owner_clause(owner_id))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Filters
|
|
||||||
if name_contains:
|
if name_contains:
|
||||||
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||||
|
|
||||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
|
|
||||||
base = _apply_metadata_filter(base, metadata_filter)
|
base = _apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
# Sort
|
# Sort
|
||||||
@ -368,13 +365,14 @@ async def list_asset_infos_page(
|
|||||||
select(func.count())
|
select(func.count())
|
||||||
.select_from(AssetInfo)
|
.select_from(AssetInfo)
|
||||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||||
|
.where(visible_owner_clause(owner_id))
|
||||||
)
|
)
|
||||||
if name_contains:
|
if name_contains:
|
||||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%"))
|
||||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
total = (await session.execute(count_stmt)).scalar_one()
|
total = int((await session.execute(count_stmt)).scalar_one() or 0)
|
||||||
|
|
||||||
# Fetch rows
|
# Fetch rows
|
||||||
infos = (await session.execute(base)).scalars().unique().all()
|
infos = (await session.execute(base)).scalars().unique().all()
|
||||||
@ -394,13 +392,22 @@ async def list_asset_infos_page(
|
|||||||
return infos, tag_map, total
|
return infos, tag_map, total
|
||||||
|
|
||||||
|
|
||||||
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
|
async def fetch_asset_info_and_asset(
|
||||||
row = await session.execute(
|
session: AsyncSession,
|
||||||
|
*,
|
||||||
|
asset_info_id: int,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> Optional[tuple[AssetInfo, Asset]]:
|
||||||
|
stmt = (
|
||||||
select(AssetInfo, Asset)
|
select(AssetInfo, Asset)
|
||||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||||
.where(AssetInfo.id == asset_info_id)
|
.where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
row = await session.execute(stmt)
|
||||||
pair = row.first()
|
pair = row.first()
|
||||||
if not pair:
|
if not pair:
|
||||||
return None
|
return None
|
||||||
@ -411,18 +418,17 @@ async def fetch_asset_info_asset_and_tags(
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
asset_info_id: int,
|
asset_info_id: int,
|
||||||
|
owner_id: str = "",
|
||||||
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
) -> Optional[tuple[AssetInfo, Asset, list[str]]]:
|
||||||
"""Fetch AssetInfo, its Asset, and all tag names.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(AssetInfo, Asset, [tag_names]) or None if the asset_info_id does not exist.
|
|
||||||
"""
|
|
||||||
stmt = (
|
stmt = (
|
||||||
select(AssetInfo, Asset, Tag.name)
|
select(AssetInfo, Asset, Tag.name)
|
||||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||||
.where(AssetInfo.id == asset_info_id)
|
.where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
.options(noload(AssetInfo.tags))
|
.options(noload(AssetInfo.tags))
|
||||||
.order_by(Tag.name.asc())
|
.order_by(Tag.name.asc())
|
||||||
)
|
)
|
||||||
@ -511,11 +517,12 @@ async def create_asset_info_for_existing_asset(
|
|||||||
tags: Optional[Sequence[str]] = None,
|
tags: Optional[Sequence[str]] = None,
|
||||||
tag_origin: str = "manual",
|
tag_origin: str = "manual",
|
||||||
added_by: Optional[str] = None,
|
added_by: Optional[str] = None,
|
||||||
|
owner_id: str = "",
|
||||||
) -> AssetInfo:
|
) -> AssetInfo:
|
||||||
"""Create a new AssetInfo referencing an existing Asset (no content write)."""
|
"""Create a new AssetInfo referencing an existing Asset (no content write)."""
|
||||||
now = utcnow()
|
now = utcnow()
|
||||||
info = AssetInfo(
|
info = AssetInfo(
|
||||||
owner_id="",
|
owner_id=owner_id,
|
||||||
name=name,
|
name=name,
|
||||||
asset_hash=asset_hash,
|
asset_hash=asset_hash,
|
||||||
preview_hash=None,
|
preview_hash=None,
|
||||||
@ -593,6 +600,7 @@ async def update_asset_info_full(
|
|||||||
user_metadata: Optional[dict] = None,
|
user_metadata: Optional[dict] = None,
|
||||||
tag_origin: str = "manual",
|
tag_origin: str = "manual",
|
||||||
added_by: Optional[str] = None,
|
added_by: Optional[str] = None,
|
||||||
|
asset_info_row: Any = None,
|
||||||
) -> AssetInfo:
|
) -> AssetInfo:
|
||||||
"""
|
"""
|
||||||
Update AssetInfo fields:
|
Update AssetInfo fields:
|
||||||
@ -601,9 +609,12 @@ async def update_asset_info_full(
|
|||||||
- replace tags with provided set (if provided)
|
- replace tags with provided set (if provided)
|
||||||
Returns the updated AssetInfo.
|
Returns the updated AssetInfo.
|
||||||
"""
|
"""
|
||||||
info = await session.get(AssetInfo, asset_info_id)
|
if not asset_info_row:
|
||||||
if not info:
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
else:
|
||||||
|
info = asset_info_row
|
||||||
|
|
||||||
touched = False
|
touched = False
|
||||||
if name is not None and name != info.name:
|
if name is not None and name != info.name:
|
||||||
@ -633,9 +644,12 @@ async def update_asset_info_full(
|
|||||||
return info
|
return info
|
||||||
|
|
||||||
|
|
||||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
|
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int, owner_id: str) -> bool:
|
||||||
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
|
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
|
||||||
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id))
|
res = await session.execute(delete(AssetInfo).where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
))
|
||||||
return bool(res.rowcount)
|
return bool(res.rowcount)
|
||||||
|
|
||||||
|
|
||||||
@ -691,25 +705,24 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
|
|||||||
|
|
||||||
|
|
||||||
async def list_tags_with_usage(
|
async def list_tags_with_usage(
|
||||||
session,
|
session: AsyncSession,
|
||||||
*,
|
*,
|
||||||
prefix: Optional[str] = None,
|
prefix: Optional[str] = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
include_zero: bool = True,
|
include_zero: bool = True,
|
||||||
order: str = "count_desc", # "count_desc" | "name_asc"
|
order: str = "count_desc", # "count_desc" | "name_asc"
|
||||||
|
owner_id: str = "",
|
||||||
) -> tuple[list[tuple[str, str, int]], int]:
|
) -> tuple[list[tuple[str, str, int]], int]:
|
||||||
"""
|
# Subquery with counts by tag_name and owner_id
|
||||||
Returns:
|
|
||||||
rows: list of (name, tag_type, count)
|
|
||||||
total: number of tags matching filter (independent of pagination)
|
|
||||||
"""
|
|
||||||
# Subquery with counts by tag_name
|
|
||||||
counts_sq = (
|
counts_sq = (
|
||||||
select(
|
select(
|
||||||
AssetInfoTag.tag_name.label("tag_name"),
|
AssetInfoTag.tag_name.label("tag_name"),
|
||||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||||
)
|
)
|
||||||
|
.select_from(AssetInfoTag)
|
||||||
|
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||||
|
.where(visible_owner_clause(owner_id))
|
||||||
.group_by(AssetInfoTag.tag_name)
|
.group_by(AssetInfoTag.tag_name)
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
@ -765,14 +778,16 @@ async def add_tags_to_asset_info(
|
|||||||
origin: str = "manual",
|
origin: str = "manual",
|
||||||
added_by: Optional[str] = None,
|
added_by: Optional[str] = None,
|
||||||
create_if_missing: bool = True,
|
create_if_missing: bool = True,
|
||||||
|
asset_info_row: Any = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Adds tags to an AssetInfo.
|
"""Adds tags to an AssetInfo.
|
||||||
If create_if_missing=True, missing tag rows are created as 'user'.
|
If create_if_missing=True, missing tag rows are created as 'user'.
|
||||||
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
|
Returns: {"added": [...], "already_present": [...], "total_tags": [...]}
|
||||||
"""
|
"""
|
||||||
info = await session.get(AssetInfo, asset_info_id)
|
if not asset_info_row:
|
||||||
if not info:
|
info = await session.get(AssetInfo, asset_info_id)
|
||||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
norm = normalize_tags(tags)
|
norm = normalize_tags(tags)
|
||||||
if not norm:
|
if not norm:
|
||||||
|
|||||||
@ -36,7 +36,7 @@ from app.user_manager import UserManager
|
|||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||||
from app.api.assets_routes import register_assets_routes
|
from app.api.assets_routes import register_assets_system
|
||||||
from protocol import BinaryEventTypes
|
from protocol import BinaryEventTypes
|
||||||
|
|
||||||
async def send_socket_catch_exception(function, message):
|
async def send_socket_catch_exception(function, message):
|
||||||
@ -182,7 +182,7 @@ class PromptServer():
|
|||||||
else args.front_end_root
|
else args.front_end_root
|
||||||
)
|
)
|
||||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||||
register_assets_routes(self.app)
|
register_assets_system(self.app, self.user_manager)
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
self.routes = routes
|
self.routes = routes
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user