diff --git a/app/database/helpers/tags.py b/app/database/helpers/tags.py
index 479343096..c8f9e5074 100644
--- a/app/database/helpers/tags.py
+++ b/app/database/helpers/tags.py
@@ -1,6 +1,8 @@
 from typing import Iterable
 
 from sqlalchemy import delete, select
+from sqlalchemy.dialects import postgresql as d_pg
+from sqlalchemy.dialects import sqlite as d_sqlite
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from ..._assets_helpers import normalize_tags
@@ -13,13 +15,29 @@ async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_typ
     if not wanted:
         return []
     existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
+    existing_names = {t.name for t in existing}
+    missing = [n for n in wanted if n not in existing_names]
+    if missing:
+        dialect = session.bind.dialect.name
+        rows = [{"name": n, "tag_type": tag_type} for n in missing]
+        if dialect == "sqlite":
+            ins = (
+                d_sqlite.insert(Tag)
+                .values(rows)
+                .on_conflict_do_nothing(index_elements=[Tag.name])
+            )
+        elif dialect == "postgresql":
+            ins = (
+                d_pg.insert(Tag)
+                .values(rows)
+                .on_conflict_do_nothing(index_elements=[Tag.name])
+            )
+        else:
+            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+        await session.execute(ins)
+        existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
     by_name = {t.name: t for t in existing}
-    to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
-    if to_create:
-        session.add_all(to_create)
-        await session.flush()
-        by_name.update({t.name: t for t in to_create})
-    return [by_name[n] for n in wanted]
+    return [by_name[n] for n in wanted if n in by_name]
 
 
 async def add_missing_tag_for_asset_id(
diff --git a/app/database/services/content.py b/app/database/services/content.py
index f8c43abfd..8388e524d 100644
--- a/app/database/services/content.py
+++ b/app/database/services/content.py
@@ -484,6 +484,7 @@ async def ingest_fs_asset(
     """
     locator = os.path.abspath(abs_path)
     now = utcnow()
+    dialect = session.bind.dialect.name
 
     if preview_id:
         if not await session.get(Asset, preview_id):
@@ -502,10 +503,34 @@
         await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
     ).scalars().first()
     if not asset:
-        async with session.begin_nested():
-            asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now)
-            session.add(asset)
-            await session.flush()
+        vals = {
+            "hash": asset_hash,
+            "size_bytes": int(size_bytes),
+            "mime_type": mime_type,
+            "created_at": now,
+        }
+        if dialect == "sqlite":
+            ins = (
+                d_sqlite.insert(Asset)
+                .values(**vals)
+                .on_conflict_do_nothing(index_elements=[Asset.hash])
+            )
+        elif dialect == "postgresql":
+            ins = (
+                d_pg.insert(Asset)
+                .values(**vals)
+                .on_conflict_do_nothing(index_elements=[Asset.hash])
+            )
+        else:
+            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+        res = await session.execute(ins)
+        rowcount = int(res.rowcount or 0)
+        asset = (
+            await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
+        ).scalars().first()
+        if not asset:
+            raise RuntimeError("Asset row not found after upsert.")
+        if rowcount > 0:
             out["asset_created"] = True
     else:
         changed = False
@@ -524,7 +549,6 @@
         "file_path": locator,
         "mtime_ns": int(mtime_ns),
     }
-    dialect = session.bind.dialect.name
     if dialect == "sqlite":
         ins = (
             d_sqlite.insert(AssetCacheState)
diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py
index 3bfb62ca4..c1a962d26 100644
--- a/tests-assets/test_uploads.py
+++ b/tests-assets/test_uploads.py
@@ -1,4 +1,7 @@
+import asyncio
 import json
+import uuid
+
 import aiohttp
 import pytest
 
@@ -125,6 +128,54 @@ async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSessio
     assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("root", ["input", "output"])
+async def test_concurrent_upload_identical_bytes_different_names(
+    root: str,
+    http: aiohttp.ClientSession,
+    api_base: str,
+    make_asset_bytes,
+):
+    """
+    Two concurrent uploads of identical bytes but different names.
+    Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
+    """
+    scope = f"concupload-{uuid.uuid4().hex[:6]}"
+    name1, name2 = "cu_a.bin", "cu_b.bin"
+    data = make_asset_bytes("concurrent", 4096)
+    tags = [root, "unit-tests", scope]
+
+    def _form(name: str) -> aiohttp.FormData:
+        f = aiohttp.FormData()
+        f.add_field("file", data, filename=name, content_type="application/octet-stream")
+        f.add_field("tags", json.dumps(tags))
+        f.add_field("name", name)
+        f.add_field("user_metadata", json.dumps({}))
+        return f
+
+    r1, r2 = await asyncio.gather(
+        http.post(api_base + "/api/assets", data=_form(name1)),
+        http.post(api_base + "/api/assets", data=_form(name2)),
+    )
+    b1, b2 = await r1.json(), await r2.json()
+    assert r1.status in (200, 201), b1
+    assert r2.status in (200, 201), b2
+    assert b1["asset_hash"] == b2["asset_hash"]
+    assert b1["id"] != b2["id"]
+
+    created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
+    assert created_flags == [False, True]
+
+    async with http.get(
+        api_base + "/api/assets",
+        params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
+    ) as rl:
+        bl = await rl.json()
+        assert rl.status == 200, bl
+        names = [a["name"] for a in bl.get("assets", [])]
+        assert set([name1, name2]).issubset(names)
+
+
 @pytest.mark.asyncio
 async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
     payload = {