fix+test: escape "_" symbol in tags filtering

This commit is contained in:
bigcat88 2025-09-15 17:29:27 +03:00
parent 5f187fe6fb
commit f3cf99d10c
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
4 changed files with 39 additions and 2 deletions

View File

@ -1,3 +1,4 @@
from .escape_like import escape_like_prefix
from .filters import apply_metadata_filter, apply_tag_filters from .filters import apply_metadata_filter, apply_tag_filters
from .ownership import visible_owner_clause from .ownership import visible_owner_clause
from .projection import is_scalar, project_kv from .projection import is_scalar, project_kv
@ -10,6 +11,7 @@ from .tags import (
__all__ = [ __all__ = [
"apply_tag_filters", "apply_tag_filters",
"apply_metadata_filter", "apply_metadata_filter",
"escape_like_prefix",
"is_scalar", "is_scalar",
"project_kv", "project_kv",
"ensure_tags_exist", "ensure_tags_exist",

View File

@ -0,0 +1,7 @@
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
    """Escape SQL LIKE metacharacters in a literal prefix.

    Every occurrence of ``%``, ``_`` and of the escape character itself in
    *s* is prefixed with *escape*, so the result matches only the literal
    text when used in a LIKE pattern.

    Args:
        s: Raw prefix text to be matched literally.
        escape: The single character to use as the LIKE escape character.

    Returns:
        ``(escaped_prefix, escape_char)``. The caller should append ``'%'``
        to the escaped prefix and pass ``escape=escape_char`` to ``.like()``.

    Raises:
        ValueError: If *escape* is not exactly one character.
    """
    if len(escape) != 1:
        raise ValueError("escape must be a single character")

    # A single translate() pass rewrites each input character exactly once.
    # Chained .replace() calls would re-escape the output of an earlier call
    # whenever the escape character is itself '%' or '_'.
    table = str.maketrans(
        {
            "%": escape + "%",
            "_": escape + "_",
            escape: escape + escape,
        }
    )
    return s.translate(table), escape

View File

@ -13,6 +13,7 @@ from ..helpers import (
apply_metadata_filter, apply_metadata_filter,
apply_tag_filters, apply_tag_filters,
ensure_tags_exist, ensure_tags_exist,
escape_like_prefix,
project_kv, project_kv,
visible_owner_clause, visible_owner_clause,
) )
@ -527,7 +528,8 @@ async def list_tags_with_usage(
) )
if prefix: if prefix:
q = q.where(Tag.name.like(prefix.strip().lower() + "%")) escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero: if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0) q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
@ -539,7 +541,8 @@ async def list_tags_with_usage(
total_q = select(func.count()).select_from(Tag) total_q = select(func.count()).select_from(Tag)
if prefix: if prefix:
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%")) escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero: if not include_zero:
total_q = total_q.where( total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name)) Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))

View File

@ -201,3 +201,28 @@ async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_ba
b3 = await r3.json() b3 = await r3.json()
assert r3.status == 400 assert r3.status == 400
assert b3["error"]["code"] == "INVALID_QUERY" assert b3["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_tags_prefix_treats_underscore_literal(
    http,
    api_base,
    asset_factory,
    make_asset_bytes,
):
    """The /api/tags 'prefix' filter must match '_' literally, never as a wildcard."""
    base = f"pref_{uuid.uuid4().hex[:6]}"
    prefix = base + "_"
    tag_ok = prefix + "ok"  # begins with the literal prefix, so it must be listed
    tag_bad = base + "xok"  # would only match if '_' acted as a one-char wildcard
    scope = "tags-underscore-" + uuid.uuid4().hex[:6]

    # One asset per tag so both tags have a non-zero usage count.
    for filename, seed, tag in (("t1.bin", "t1", tag_ok), ("t2.bin", "t2", tag_bad)):
        await asset_factory(
            filename,
            ["input", "unit-tests", scope, tag],
            {},
            make_asset_bytes(seed, 512),
        )

    query = {"include_zero": "false", "prefix": prefix}
    async with http.get(api_base + "/api/tags", params=query) as r:
        body = await r.json()
        assert r.status == 200, body

    names = [t["name"] for t in body["tags"]]
    assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
    assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
    assert body["total"] == 1