diff --git a/app/database/helpers/__init__.py b/app/database/helpers/__init__.py
index 310583607..5a8e87b2c 100644
--- a/app/database/helpers/__init__.py
+++ b/app/database/helpers/__init__.py
@@ -1,3 +1,4 @@
+from .escape_like import escape_like_prefix
 from .filters import apply_metadata_filter, apply_tag_filters
 from .ownership import visible_owner_clause
 from .projection import is_scalar, project_kv
@@ -10,6 +11,7 @@ from .tags import (
 __all__ = [
     "apply_tag_filters",
     "apply_metadata_filter",
+    "escape_like_prefix",
     "is_scalar",
     "project_kv",
     "ensure_tags_exist",
diff --git a/app/database/helpers/escape_like.py b/app/database/helpers/escape_like.py
new file mode 100644
index 000000000..f905bd40b
--- /dev/null
+++ b/app/database/helpers/escape_like.py
@@ -0,0 +1,22 @@
+def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
+    """Escape ``%``, ``_`` and the escape character itself in a LIKE prefix.
+
+    Args:
+        s: Raw prefix that must be matched literally.
+        escape: Single character used as the SQL ``ESCAPE`` character. It
+            must not be ``%`` or ``_``, which are the LIKE wildcards.
+
+    Returns:
+        ``(escaped_prefix, escape_char)``. The caller should append ``'%'``
+        and pass ``escape=escape_char`` to ``.like()``.
+
+    Raises:
+        ValueError: If ``escape`` is not exactly one character or is itself
+            a LIKE wildcard.
+    """
+    if len(escape) != 1 or escape in "%_":
+        raise ValueError("escape must be a single character other than '%' or '_'")
+    # One pass over the input: each special character gets the escape
+    # character prepended, so the result cannot depend on replacement order.
+    table = str.maketrans({ch: escape + ch for ch in (escape, "%", "_")})
+    return s.translate(table), escape
diff --git a/app/database/services/info.py b/app/database/services/info.py
index d2fd1f503..13af76ea0 100644
--- a/app/database/services/info.py
+++ b/app/database/services/info.py
@@ -13,6 +13,7 @@ from ..helpers import (
     apply_metadata_filter,
     apply_tag_filters,
     ensure_tags_exist,
+    escape_like_prefix,
     project_kv,
     visible_owner_clause,
 )
@@ -527,7 +528,8 @@ async def list_tags_with_usage(
     )
 
     if prefix:
-        q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
+        escaped, esc = escape_like_prefix(prefix.strip().lower())
+        q = q.where(Tag.name.like(escaped + "%", escape=esc))
 
     if not include_zero:
         q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
@@ -539,7 +541,8 @@ async def list_tags_with_usage(
 
     total_q = select(func.count()).select_from(Tag)
     if prefix:
-        total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
+        escaped, esc = escape_like_prefix(prefix.strip().lower())
+        total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
     if not include_zero:
         total_q = total_q.where(
             Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py
index 9ad3c3f86..9bdf770c4 100644
--- a/tests-assets/test_tags.py
+++ b/tests-assets/test_tags.py
@@ -201,3 +201,28 @@ async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_ba
     b3 = await r3.json()
     assert r3.status == 400
     assert b3["error"]["code"] == "INVALID_QUERY"
+
+
+@pytest.mark.asyncio
+async def test_tags_prefix_treats_underscore_literal(
+    http,
+    api_base,
+    asset_factory,
+    make_asset_bytes,
+):
+    """'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
+    base = f"pref_{uuid.uuid4().hex[:6]}"
+    tag_ok = f"{base}_ok"  # should match prefix=f"{base}_"
+    tag_bad = f"{base}xok"  # must NOT match if '_' is escaped
+    scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
+
+    await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
+    await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
+
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r:
+        body = await r.json()
+        assert r.status == 200, body
+        names = [t["name"] for t in body["tags"]]
+        assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
+        assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
+        assert body["total"] == 1