mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-15 04:35:01 +08:00
[Bugfix] Fallback to outlines for complex json schemas (#10899)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
7883c2bbe7
commit
8d370e91cb
@ -69,6 +69,37 @@ def sample_json_schema():
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_complex_json_schema():
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"score": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 100 # Numeric range
|
||||||
|
},
|
||||||
|
"grade": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[A-D]$" # Regex pattern
|
||||||
|
},
|
||||||
|
"email": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
||||||
|
},
|
||||||
|
"tags": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"pattern":
|
||||||
|
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["score", "grade", "email", "tags"]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_guided_choice():
|
def sample_guided_choice():
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm):
|
|||||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
def test_guided_complex_json_completion(sample_complex_json_schema, llm):
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema))
|
||||||
|
outputs = llm.generate(prompts=[
|
||||||
|
f"Give an example JSON for an assignment grade "
|
||||||
|
f"that fits this schema: {sample_complex_json_schema}"
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
jsonschema.validate(instance=output_json,
|
||||||
|
schema=sample_complex_json_schema)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
def test_guided_choice_completion(sample_guided_choice, llm):
|
def test_guided_choice_completion(sample_guided_choice, llm):
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
|
|||||||
@ -15,6 +15,40 @@ if TYPE_CHECKING:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
||||||
|
"""Check if JSON schema contains features unsupported by xgrammar."""
|
||||||
|
|
||||||
|
def check_object(obj: dict) -> bool:
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for pattern restrictions
|
||||||
|
if "pattern" in obj:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for numeric ranges
|
||||||
|
if obj.get("type") in ("integer", "number") and any(
|
||||||
|
key in obj for key in [
|
||||||
|
"minimum", "maximum", "exclusiveMinimum",
|
||||||
|
"exclusiveMaximum", "multipleOf"
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Recursively check all nested objects and arrays
|
||||||
|
for value in obj.values():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
if check_object(value):
|
||||||
|
return True
|
||||||
|
elif isinstance(value, list):
|
||||||
|
for item in value:
|
||||||
|
if isinstance(item, dict) and check_object(item):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
return check_object(schema)
|
||||||
|
|
||||||
|
|
||||||
def maybe_backend_fallback(
|
def maybe_backend_fallback(
|
||||||
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
||||||
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
||||||
@ -47,6 +81,15 @@ def maybe_backend_fallback(
|
|||||||
"Falling back to use outlines instead.")
|
"Falling back to use outlines instead.")
|
||||||
guided_params.backend = "outlines"
|
guided_params.backend = "outlines"
|
||||||
|
|
||||||
|
# xgrammar doesn't support some JSON schema features
|
||||||
|
elif (guided_params.json is not None
|
||||||
|
and has_xgrammar_unsupported_json_features(guided_params.json)):
|
||||||
|
logger.warning(
|
||||||
|
"xgrammar does not support advanced JSON schema features like "
|
||||||
|
"patterns or numeric ranges. "
|
||||||
|
"Falling back to use outlines instead.")
|
||||||
|
guided_params.backend = "outlines"
|
||||||
|
|
||||||
return guided_params
|
return guided_params
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user