mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-14 08:44:58 +08:00
[Bugfix] remove fallback in guided_json (int range, patterns) (#16725)
Signed-off-by: csy1204 <josang1204@gmail.com> Co-authored-by: 조상연[플레이스 AI] <sang-yeon.cho@navercorp.com>
This commit is contained in:
parent
b22980a1dc
commit
6aae216b4e
@ -305,7 +305,7 @@ def test_disable_guided_decoding_fallback(sample_regex, llm):
|
|||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="xgrammar does not support advanced JSON schema features "
|
match="xgrammar does not support advanced JSON schema features "
|
||||||
"like enums, patterns or numeric ranges."):
|
"like string length, item limits, or property bounds."):
|
||||||
llm.generate(prompts="This should fail",
|
llm.generate(prompts="This should fail",
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
@ -386,6 +386,62 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str):
|
|||||||
jsonschema.validate(instance=output_json, schema=json_schema)
|
jsonschema.validate(instance=output_json, schema=json_schema)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
|
||||||
|
def test_guided_number_range_json_completion(llm,
|
||||||
|
guided_decoding_backend: str):
|
||||||
|
sample_output_schema = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"age": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 18,
|
||||||
|
"maximum": 99
|
||||||
|
},
|
||||||
|
"score": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0.0,
|
||||||
|
"maximum": 100.0
|
||||||
|
},
|
||||||
|
"zipcode": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": r"^\d{5}(-\d{4})?$"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["age", "score", "zipcode"],
|
||||||
|
}
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(json=sample_output_schema,
|
||||||
|
backend=guided_decoding_backend),
|
||||||
|
)
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=[
|
||||||
|
"Create a JSON object for a user with age, score, and zipcode."
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
jsonschema.validate(instance=output_json, schema=sample_output_schema)
|
||||||
|
assert 18 <= output_json["age"] <= 99
|
||||||
|
assert 0.0 <= output_json["score"] <= 100.0
|
||||||
|
assert (re.fullmatch(r"^\d{5}(-\d{4})?$", output_json["zipcode"])
|
||||||
|
is not None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
def test_guidance_no_additional_properties(llm):
|
def test_guidance_no_additional_properties(llm):
|
||||||
schema = {
|
schema = {
|
||||||
|
|||||||
@ -47,6 +47,14 @@ def sample_json_schema():
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"grade": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[A-D]$" # Regex pattern
|
||||||
|
},
|
||||||
|
"email": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
||||||
|
},
|
||||||
"work_history": {
|
"work_history": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
@ -56,17 +64,20 @@ def sample_json_schema():
|
|||||||
"type": "string"
|
"type": "string"
|
||||||
},
|
},
|
||||||
"duration": {
|
"duration": {
|
||||||
"type": "number"
|
"type": "number",
|
||||||
|
"minimum": 0.0,
|
||||||
|
"maximum": 100.0, # Numeric range
|
||||||
},
|
},
|
||||||
"position": {
|
"position": {
|
||||||
"type": "string"
|
"type": "string"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["company", "position"]
|
"required": ["company", "duration", "position"]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["name", "age", "skills", "work_history"]
|
"required":
|
||||||
|
["name", "age", "skills", "grade", "email", "work_history"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -78,27 +89,18 @@ def unsupported_json_schema():
|
|||||||
"properties": {
|
"properties": {
|
||||||
"score": {
|
"score": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"minimum": 0,
|
"multipleOf": 5 # Numeric multiple
|
||||||
"maximum": 100 # Numeric range
|
|
||||||
},
|
|
||||||
"grade": {
|
|
||||||
"type": "string",
|
|
||||||
"pattern": "^[A-D]$" # Regex pattern
|
|
||||||
},
|
|
||||||
"email": {
|
|
||||||
"type": "string",
|
|
||||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
|
||||||
},
|
},
|
||||||
"tags": {
|
"tags": {
|
||||||
"type": "array",
|
"type": "array",
|
||||||
"items": {
|
"items": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"pattern":
|
"minLength": 10,
|
||||||
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
"maxLength": 20
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["score", "grade", "email", "tags"]
|
"required": ["score", "tags"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -9,10 +9,6 @@ from vllm.v1.structured_output.backend_xgrammar import (
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def unsupported_string_schemas():
|
def unsupported_string_schemas():
|
||||||
return [
|
return [
|
||||||
{
|
|
||||||
"type": "string",
|
|
||||||
"pattern": "^[a-zA-Z]+$"
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"format": "email"
|
"format": "email"
|
||||||
@ -23,22 +19,6 @@ def unsupported_string_schemas():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def unsupported_integer_schemas():
|
def unsupported_integer_schemas():
|
||||||
return [
|
return [
|
||||||
{
|
|
||||||
"type": "integer",
|
|
||||||
"minimum": 0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "integer",
|
|
||||||
"maximum": 120
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "integer",
|
|
||||||
"exclusiveMinimum": 120
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "integer",
|
|
||||||
"exclusiveMaximum": 120
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"multipleOf": 120
|
"multipleOf": 120
|
||||||
@ -49,22 +29,6 @@ def unsupported_integer_schemas():
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def unsupported_number_schemas():
|
def unsupported_number_schemas():
|
||||||
return [
|
return [
|
||||||
{
|
|
||||||
"type": "number",
|
|
||||||
"minimum": 0
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number",
|
|
||||||
"maximum": 120
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number",
|
|
||||||
"exclusiveMinimum": 120
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "number",
|
|
||||||
"exclusiveMaximum": 120
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
"type": "number",
|
"type": "number",
|
||||||
"multipleOf": 120
|
"multipleOf": 120
|
||||||
@ -156,13 +120,28 @@ def supported_schema():
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"enum": ["sedan", "suv", "truck"]
|
"enum": ["sedan", "suv", "truck"]
|
||||||
},
|
},
|
||||||
|
"car_brand": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[a-zA-Z]+$"
|
||||||
|
},
|
||||||
"short_description": {
|
"short_description": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"maxLength": 50
|
"maxLength": 50
|
||||||
},
|
},
|
||||||
|
"mileage": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 1000000
|
||||||
|
},
|
||||||
|
"model_year": {
|
||||||
|
"type": "integer",
|
||||||
|
"exclusiveMinimum": 1900,
|
||||||
|
"exclusiveMaximum": 2100
|
||||||
|
},
|
||||||
"long_description": {
|
"long_description": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"minLength": 50
|
"minLength": 50,
|
||||||
|
"maxLength": 2000
|
||||||
},
|
},
|
||||||
"address": {
|
"address": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
|
|||||||
@ -65,7 +65,7 @@ def maybe_backend_fallback(
|
|||||||
fallback_or_error(
|
fallback_or_error(
|
||||||
guided_params,
|
guided_params,
|
||||||
"xgrammar does not support advanced JSON schema features like "
|
"xgrammar does not support advanced JSON schema features like "
|
||||||
"enums, patterns or numeric ranges.", "outlines")
|
"string length, item limits, or property bounds.", "outlines")
|
||||||
|
|
||||||
# xgrammar only supports GBNF grammars, so we must convert Lark.
|
# xgrammar only supports GBNF grammars, so we must convert Lark.
|
||||||
# We must check if the grammar is likely Lark and if that
|
# We must check if the grammar is likely Lark and if that
|
||||||
|
|||||||
@ -10,16 +10,8 @@ def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
|||||||
if not isinstance(obj, dict):
|
if not isinstance(obj, dict):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check for pattern restrictions
|
|
||||||
if "pattern" in obj:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check for numeric ranges
|
# Check for numeric ranges
|
||||||
if obj.get("type") in ("integer", "number") and any(
|
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
|
||||||
key in obj for key in [
|
|
||||||
"minimum", "maximum", "exclusiveMinimum",
|
|
||||||
"exclusiveMaximum", "multipleOf"
|
|
||||||
]):
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for array unsupported keywords
|
# Check for array unsupported keywords
|
||||||
|
|||||||
@ -179,15 +179,8 @@ def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
|
|||||||
if not isinstance(obj, dict):
|
if not isinstance(obj, dict):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Check for pattern restrictions
|
|
||||||
if "pattern" in obj:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check for numeric ranges
|
# Check for numeric ranges
|
||||||
if obj.get("type") in ("integer", "number") and any(
|
if obj.get("type") in ("integer", "number") and ("multipleOf" in obj):
|
||||||
key in obj
|
|
||||||
for key in ("minimum", "maximum", "exclusiveMinimum",
|
|
||||||
"exclusiveMaximum", "multipleOf")):
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check for array unsupported keywords
|
# Check for array unsupported keywords
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user