[Bugfix] remove fallback in guided_json (int range, patterns) (#16725)

Signed-off-by: csy1204 <josang1204@gmail.com>
Co-authored-by: 조상연[플레이스 AI] <sang-yeon.cho@navercorp.com>
This commit is contained in:
Sangyeon Cho 2025-04-25 15:54:43 +09:00 committed by GitHub
parent b22980a1dc
commit 6aae216b4e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 94 additions and 72 deletions

View File

@ -305,7 +305,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="xgrammar does not support advanced JSON schema features " match="xgrammar does not support advanced JSON schema features "
"like enums, patterns or numeric ranges."): "like string length, item limits, or property bounds."):
llm.generate(prompts="This should fail", llm.generate(prompts="This should fail",
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=True) use_tqdm=True)
@ -386,6 +386,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
jsonschema.validate(instance=output_json, schema=json_schema) jsonschema.validate(instance=output_json, schema=json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
def test_guided_number_range_json_completion(llm,
guided_decoding_backend: str):
sample_output_schema = {
"type": "object",
"properties": {
"age": {
"type": "integer",
"minimum": 18,
"maximum": 99
},
"score": {
"type": "number",
"minimum": 0.0,
"maximum": 100.0
},
"zipcode": {
"type": "string",
"pattern": r"^\d{5}(-\d{4})?$"
},
},
"required": ["age", "score", "zipcode"],
}
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(json=sample_output_schema,
backend=guided_decoding_backend),
)
outputs = llm.generate(
prompts=[
"Create a JSON object for a user with age, score, and zipcode."
] * 2,
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_output_schema)
assert 18 <= output_json["age"] <= 99
assert 0.0 <= output_json["score"] <= 100.0
assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
is not None)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_guidance_no_additional_properties(llm): def test_guidance_no_additional_properties(llm):
schema = { schema = {

View File

@ -47,6 +47,14 @@ def sample_json_schema():
"type": "string", "type": "string",
} }
}, },
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"email": {
"type": "string",
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
},
"work_history": { "work_history": {
"type": "array", "type": "array",
"items": { "items": {
@ -56,17 +64,20 @@ def sample_json_schema():
"type": "string" "type": "string"
}, },
"duration": { "duration": {
"type": "number" "type": "number",
"minimum": 0.0,
"maximum": 100.0, # Numeric range
}, },
"position": { "position": {
"type": "string" "type": "string"
} }
}, },
"required": ["company", "position"] "required": ["company", "duration", "position"]
} }
} }
}, },
"required": ["name", "age", "skills", "work_history"] "required":
["name", "age", "skills", "grade", "email", "work_history"]
} }
@ -78,27 +89,18 @@ def unsupported_json_schema():
"properties": { "properties": {
"score": { "score": {
"type": "integer", "type": "integer",
"minimum": 0, "multipleOf": 5 # Numeric multiple
"maximum": 100 # Numeric range
},
"grade": {
"type": "string",
"pattern": "^[A-D]$" # Regex pattern
},
"email": {
"type": "string",
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
}, },
"tags": { "tags": {
"type": "array", "type": "array",
"items": { "items": {
"type": "string", "type": "string",
"pattern": "minLength": 10,
"^[a-z]{1,10}$" # Combining length and pattern restrictions "maxLength": 20
} }
} }
}, },
"required": ["score", "grade", "email", "tags"] "required": ["score", "tags"]
} }

View File

@ -9,10 +9,6 @@ from vllm.v1.structured_output.backend_xgrammar import (
@pytest.fixture @pytest.fixture
def unsupported_string_schemas(): def unsupported_string_schemas():
return [ return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{ {
"type": "string", "type": "string",
"format": "email" "format": "email"
@ -23,22 +19,6 @@ def unsupported_string_schemas():
@pytest.fixture @pytest.fixture
def unsupported_integer_schemas(): def unsupported_integer_schemas():
return [ return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{ {
"type": "integer", "type": "integer",
"multipleOf": 120 "multipleOf": 120
@ -49,22 +29,6 @@ def unsupported_integer_schemas():
@pytest.fixture @pytest.fixture
def unsupported_number_schemas(): def unsupported_number_schemas():
return [ return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{ {
"type": "number", "type": "number",
"multipleOf": 120 "multipleOf": 120
@ -156,13 +120,28 @@ def supported_schema():
"type": "string", "type": "string",
"enum": ["sedan", "suv", "truck"] "enum": ["sedan", "suv", "truck"]
}, },
"car_brand": {
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
"short_description": { "short_description": {
"type": "string", "type": "string",
"maxLength": 50 "maxLength": 50
}, },
"mileage": {
"type": "number",
"minimum": 0,
"maximum": 1000000
},
"model_year": {
"type": "integer",
"exclusiveMinimum": 1900,
"exclusiveMaximum": 2100
},
"long_description": { "long_description": {
"type": "string", "type": "string",
"minLength": 50 "minLength": 50,
"maxLength": 2000
}, },
"address": { "address": {
"type": "object", "type": "object",

View File

@ -65,7 +65,7 @@ def maybe_backend_fallback(
fallback_or_error( fallback_or_error(
guided_params, guided_params,
"xgrammar does not support advanced JSON schema features like " "xgrammar does not support advanced JSON schema features like "
"enums, patterns or numeric ranges.", "outlines") "string length, item limits, or property bounds.", "outlines")
# xgrammar only supports GBNF grammars, so we must convert Lark. # xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that # We must check if the grammar is likely Lark and if that

View File

@ -10,16 +10,8 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
if not isinstance(obj, dict): if not isinstance(obj, dict):
return False return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
key in obj for key in [
"minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf"
]):
return True return True
# Check for array unsupported keywords # Check for array unsupported keywords

View File

@ -179,15 +179,8 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
if not isinstance(obj, dict): if not isinstance(obj, dict):
return False return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Check for numeric ranges # Check for numeric ranges
if obj.get("type") in ("integer", "number") and any( if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
key in obj
for key in ("minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf")):
return True return True
# Check for array unsupported keywords # Check for array unsupported keywords