import os
import uuid
from typing import Iterable, Sequence

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as d_pg
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import AsyncSession

from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
from ..timeutil import utcnow

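# Budget of bind parameters per INSERT statement, presumably chosen to stay below
# SQLite's historical default limit of 999 host parameters per statement.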
MAX_BIND_PARAMS = 800


async def seed_from_paths_batch(
    session: AsyncSession,
    *,
    specs: Sequence[dict],
    owner_id: str = "",
) -> dict:
    """Seed Asset, AssetCacheState and AssetInfo rows for a batch of local file paths.

    Each spec is a dict with keys:

    - abs_path: str
    - size_bytes: int
    - mtime_ns: int
    - info_name: str
    - tags: list[str]
    - fname: Optional[str]

    Returns a dict with counts: inserted_infos, won_states, lost_states.
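
    Example (illustrative sketch; path, size and tag values are hypothetical):

        specs = [{
            "abs_path": "/some/dir/file.bin",
            "size_bytes": 1234,
            "mtime_ns": 1_700_000_000_000_000_000,
            "info_name": "file.bin",
            "tags": ["input"],
            "fname": "file.bin",
        }]
        result = await seed_from_paths_batch(session, specs=specs)
        # -> {"inserted_infos": 1, "won_states": 1, "lost_states": 0}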
"""
    if not specs:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}

    now = utcnow()
    dialect = session.bind.dialect.name
    if dialect not in ("sqlite", "postgresql"):
        raise NotImplementedError(f"Unsupported database dialect: {dialect}")

    asset_rows: list[dict] = []
    state_rows: list[dict] = []
    path_to_asset: dict[str, str] = {}
    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
    path_list: list[str] = []

    for sp in specs:
        ap = os.path.abspath(sp["abs_path"])
        aid = str(uuid.uuid4())
        iid = str(uuid.uuid4())
        path_list.append(ap)
        path_to_asset[ap] = aid

        asset_rows.append(
            {
                "id": aid,
                "hash": None,
                "size_bytes": sp["size_bytes"],
                "mime_type": None,
                "created_at": now,
            }
        )
        state_rows.append(
            {
                "asset_id": aid,
                "file_path": ap,
                "mtime_ns": sp["mtime_ns"],
            }
        )
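        # Keys prefixed with "_" are helper fields consumed further below when
        # building the tag and metadata rows; they are not AssetInfo columns.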
        asset_to_info[aid] = {
            "id": iid,
            "owner_id": owner_id,
            "name": sp["info_name"],
            "asset_id": aid,
            "preview_id": None,
            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
            "created_at": now,
            "updated_at": now,
            "last_access_time": now,
            "_tags": sp["tags"],
            "_filename": sp["fname"],
        }

    # insert all seed Assets (hash=NULL)
    ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
        await session.execute(ins_asset, chunk)

    # try to claim AssetCacheState (file_path)
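    # (ON CONFLICT DO NOTHING + RETURNING yields only the rows this statement actually
    # inserted, so the returned file_path values are exactly the claims we won.)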
    winners_by_path: set[str] = set()
    if dialect == "sqlite":
        ins_state = (
            d_sqlite.insert(AssetCacheState)
            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
            .returning(AssetCacheState.file_path)
        )
    else:
        ins_state = (
            d_pg.insert(AssetCacheState)
            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
            .returning(AssetCacheState.file_path)
        )
    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
        winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
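
    # Paths not returned above already had an AssetCacheState row; those specs lose the claim.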
    all_paths_set = set(path_list)
    losers_by_path = all_paths_set - winners_by_path
    lost_assets = [path_to_asset[p] for p in losers_by_path]
    if lost_assets:  # losers get their Asset removed
        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
            await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))

    if not winners_by_path:
        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}

    # insert AssetInfo only for winners
    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
    if dialect == "sqlite":
        ins_info = (
            d_sqlite.insert(AssetInfo)
            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
            .returning(AssetInfo.id)
        )
    else:
        ins_info = (
            d_pg.insert(AssetInfo)
            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
            .returning(AssetInfo.id)
        )

    inserted_info_ids: set[str] = set()
    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
        # drop the "_"-prefixed helper keys: they are not AssetInfo columns and would
        # otherwise be rejected by the insert as unconsumed parameters
        db_rows = [{k: v for k, v in row.items() if not k.startswith("_")} for row in chunk]
        inserted_info_ids.update((await session.execute(ins_info, db_rows)).scalars().all())

    # build and insert tag + meta rows for the AssetInfo
    tag_rows: list[dict] = []
    meta_rows: list[dict] = []
    if inserted_info_ids:
        for row in winner_info_rows:
            iid = row["id"]
            if iid not in inserted_info_ids:
                continue
            for t in row["_tags"]:
                tag_rows.append({
                    "asset_info_id": iid,
                    "tag_name": t,
                    "origin": "automatic",
                    "added_at": now,
                })
            if row["_filename"]:
                meta_rows.append(
                    {
                        "asset_info_id": iid,
                        "key": "filename",
                        "ordinal": 0,
                        "val_str": row["_filename"],
                        "val_num": None,
                        "val_bool": None,
                        "val_json": None,
                    }
                )

    await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
    return {
        "inserted_infos": len(inserted_info_ids),
        "won_states": len(winners_by_path),
        "lost_states": len(losers_by_path),
    }


async def bulk_insert_tags_and_meta(
    session: AsyncSession,
    *,
    tag_rows: list[dict],
    meta_rows: list[dict],
    max_bind_params: int,
) -> None:
    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.

    - tag_rows keys: asset_info_id, tag_name, origin, added_at
    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
    """
    dialect = session.bind.dialect.name
    if tag_rows:
        if dialect == "sqlite":
            ins_links = (
                d_sqlite.insert(AssetInfoTag)
                .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
            )
        elif dialect == "postgresql":
            ins_links = (
                d_pg.insert(AssetInfoTag)
                .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
            )
        else:
            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
            await session.execute(ins_links, chunk)
    if meta_rows:
        if dialect == "sqlite":
            ins_meta = (
                d_sqlite.insert(AssetInfoMeta)
                .on_conflict_do_nothing(
                    index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
                )
            )
        elif dialect == "postgresql":
            ins_meta = (
                d_pg.insert(AssetInfoMeta)
                .on_conflict_do_nothing(
                    index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
                )
            )
        else:
            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
            await session.execute(ins_meta, chunk)
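

# Chunking helpers: both SQLite and PostgreSQL cap the number of bind parameters a
# single statement may carry, so executemany batches are sized to
# max_bind_params // cols_per_row rows per statement. Callers pass the column count
# of the row dicts being inserted (5 for Asset, 3 for AssetCacheState, 9 for
# AssetInfo, 4 for tag rows, 7 for meta rows).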
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
    if not rows:
        return []
    rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
    for i in range(0, len(rows), rows_per_stmt):
        yield rows[i:i + rows_per_stmt]


def _iter_chunks(seq, n: int):
    for i in range(0, len(seq), n):
        yield seq[i:i + n]


def _rows_per_stmt(cols: int) -> int:
    return max(1, MAX_BIND_PARAMS // max(1, cols))