[CI] Make JSON output tests less likely to fail (#17859)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-05-12 18:31:54 -04:00 committed by GitHub
parent 2b0db9b0e2
commit ebab1ac37c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 11 deletions

View File

@ -72,12 +72,14 @@ def sample_json_schema():
"type": "string"
}
},
"required": ["company", "duration", "position"]
"required": ["company", "duration", "position"],
"additionalProperties": False
}
}
},
"required":
["name", "age", "skills", "grade", "email", "work_history"]
["name", "age", "skills", "grade", "email", "work_history"],
"additionalProperties": False
}
@ -100,7 +102,8 @@ def unsupported_json_schema():
}
}
},
"required": ["score", "tags"]
"required": ["score", "tags"],
"additionalProperties": False
}
@ -139,7 +142,8 @@ def sample_definition_json_schema():
},
'required': ['steps', 'final_answer'],
'title': 'MathReasoning',
'type': 'object'
'type': 'object',
"additionalProperties": False
}

View File

@ -62,6 +62,16 @@ class CarDescription(BaseModel):
car_type: CarType
def _load_json(s: str, backend: str) -> str:
if backend != "xgrammar":
return json.loads(s)
# xgrammar specific workarounds
# https://github.com/mlc-ai/xgrammar/issues/286
s = re.sub(r'[\x00-\x1F\x7F-\xFF]', '', s)
return json.loads(s)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize(
"model_name, guided_decoding_backend, tokenizer_mode, speculative_config",
@ -102,7 +112,7 @@ def test_structured_output(
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
outputs = llm.generate(prompts=[
(f"Give an example JSON for an employee profile that fits this "
@ -131,7 +141,7 @@ def test_structured_output(
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=100,
max_tokens=4096,
n=2,
guided_decoding=GuidedDecodingParams(json_object=True))
@ -161,7 +171,7 @@ def test_structured_output(
#
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
if guided_decoding_backend.startswith("xgrammar"):
with pytest.raises(ValueError,
@ -376,12 +386,13 @@ def test_structured_output(
"minLength": min_length
}
},
"required": ["description"]
"required": ["description"],
"additionalProperties": False
}
sampling_params = SamplingParams(
temperature=1.0,
max_tokens=1000,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(json=json_schema))
outputs = llm.generate(
@ -417,7 +428,8 @@ def test_structured_output(
"city": {
"type": "string"
}
}
},
"additionalProperties": False
},
"end": "</function>"
}],
@ -426,7 +438,7 @@ def test_structured_output(
sampling_params = SamplingParams(
temperature=0.0,
max_tokens=100,
max_tokens=4096,
guided_decoding=GuidedDecodingParams(
structural_tag=json.dumps(structural_tag_config)))