mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 04:44:57 +08:00
[Bugfix] Fallback to outlines for complex json schemas (#10899)
Signed-off-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
7883c2bbe7
commit
8d370e91cb
@ -69,6 +69,37 @@ def sample_json_schema():
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_complex_json_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"score": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"maximum": 100 # Numeric range
|
||||
},
|
||||
"grade": {
|
||||
"type": "string",
|
||||
"pattern": "^[A-D]$" # Regex pattern
|
||||
},
|
||||
"email": {
|
||||
"type": "string",
|
||||
"pattern": "^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$"
|
||||
},
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string",
|
||||
"pattern":
|
||||
"^[a-z]{1,10}$" # Combining length and pattern restrictions
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["score", "grade", "email", "tags"]
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_guided_choice():
|
||||
return [
|
||||
|
||||
@ -76,6 +76,34 @@ def test_guided_json_completion(sample_json_schema, llm):
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_guided_complex_json_completion(sample_complex_json_schema, llm):
|
||||
sampling_params = SamplingParams(
|
||||
temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(json=sample_complex_json_schema))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for an assignment grade "
|
||||
f"that fits this schema: {sample_complex_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json,
|
||||
schema=sample_complex_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_guided_choice_completion(sample_guided_choice, llm):
|
||||
sampling_params = SamplingParams(
|
||||
|
||||
@ -15,6 +15,40 @@ if TYPE_CHECKING:
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
|
||||
"""Check if JSON schema contains features unsupported by xgrammar."""
|
||||
|
||||
def check_object(obj: dict) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
# Check for pattern restrictions
|
||||
if "pattern" in obj:
|
||||
return True
|
||||
|
||||
# Check for numeric ranges
|
||||
if obj.get("type") in ("integer", "number") and any(
|
||||
key in obj for key in [
|
||||
"minimum", "maximum", "exclusiveMinimum",
|
||||
"exclusiveMaximum", "multipleOf"
|
||||
]):
|
||||
return True
|
||||
|
||||
# Recursively check all nested objects and arrays
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
if check_object(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_object(item):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return check_object(schema)
|
||||
|
||||
|
||||
def maybe_backend_fallback(
|
||||
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
||||
# lm-format-enforce doesn't support grammar, fallback to xgrammar
|
||||
@ -47,6 +81,15 @@ def maybe_backend_fallback(
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
# xgrammar doesn't support some JSON schema features
|
||||
elif (guided_params.json is not None
|
||||
and has_xgrammar_unsupported_json_features(guided_params.json)):
|
||||
logger.warning(
|
||||
"xgrammar does not support advanced JSON schema features like "
|
||||
"patterns or numeric ranges. "
|
||||
"Falling back to use outlines instead.")
|
||||
guided_params.backend = "outlines"
|
||||
|
||||
return guided_params
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user