mirror of
https://git.datalinker.icu/comfyanonymous/ComfyUI
synced 2025-12-16 01:25:08 +08:00
fix+test: escape "_" symbol in tags filtering
This commit is contained in:
parent
5f187fe6fb
commit
f3cf99d10c
@ -1,3 +1,4 @@
|
|||||||
|
from .escape_like import escape_like_prefix
|
||||||
from .filters import apply_metadata_filter, apply_tag_filters
|
from .filters import apply_metadata_filter, apply_tag_filters
|
||||||
from .ownership import visible_owner_clause
|
from .ownership import visible_owner_clause
|
||||||
from .projection import is_scalar, project_kv
|
from .projection import is_scalar, project_kv
|
||||||
@ -10,6 +11,7 @@ from .tags import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"apply_tag_filters",
|
"apply_tag_filters",
|
||||||
"apply_metadata_filter",
|
"apply_metadata_filter",
|
||||||
|
"escape_like_prefix",
|
||||||
"is_scalar",
|
"is_scalar",
|
||||||
"project_kv",
|
"project_kv",
|
||||||
"ensure_tags_exist",
|
"ensure_tags_exist",
|
||||||
|
|||||||
7
app/database/helpers/escape_like.py
Normal file
7
app/database/helpers/escape_like.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||||
|
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||||
|
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||||
|
"""
|
||||||
|
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||||
|
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||||
|
return s, escape
|
||||||
@ -13,6 +13,7 @@ from ..helpers import (
|
|||||||
apply_metadata_filter,
|
apply_metadata_filter,
|
||||||
apply_tag_filters,
|
apply_tag_filters,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
|
escape_like_prefix,
|
||||||
project_kv,
|
project_kv,
|
||||||
visible_owner_clause,
|
visible_owner_clause,
|
||||||
)
|
)
|
||||||
@ -527,7 +528,8 @@ async def list_tags_with_usage(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prefix:
|
if prefix:
|
||||||
q = q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||||
|
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||||
|
|
||||||
if not include_zero:
|
if not include_zero:
|
||||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||||
@ -539,7 +541,8 @@ async def list_tags_with_usage(
|
|||||||
|
|
||||||
total_q = select(func.count()).select_from(Tag)
|
total_q = select(func.count()).select_from(Tag)
|
||||||
if prefix:
|
if prefix:
|
||||||
total_q = total_q.where(Tag.name.like(prefix.strip().lower() + "%"))
|
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||||
|
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||||
if not include_zero:
|
if not include_zero:
|
||||||
total_q = total_q.where(
|
total_q = total_q.where(
|
||||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||||
|
|||||||
@ -201,3 +201,28 @@ async def test_tags_endpoints_invalid_bodies(http: aiohttp.ClientSession, api_ba
|
|||||||
b3 = await r3.json()
|
b3 = await r3.json()
|
||||||
assert r3.status == 400
|
assert r3.status == 400
|
||||||
assert b3["error"]["code"] == "INVALID_QUERY"
|
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tags_prefix_treats_underscore_literal(
|
||||||
|
http,
|
||||||
|
api_base,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
|
||||||
|
base = f"pref_{uuid.uuid4().hex[:6]}"
|
||||||
|
tag_ok = f"{base}_ok" # should match prefix=f"{base}_"
|
||||||
|
tag_bad = f"{base}xok" # must NOT match if '_' is escaped
|
||||||
|
scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
|
||||||
|
|
||||||
|
await asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
|
||||||
|
await asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
|
||||||
|
|
||||||
|
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}) as r:
|
||||||
|
body = await r.json()
|
||||||
|
assert r.status == 200, body
|
||||||
|
names = [t["name"] for t in body["tags"]]
|
||||||
|
assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
|
||||||
|
assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
|
||||||
|
assert body["total"] == 1
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user