[Bugfix] Pass json-schema to GuidedDecodingParams and make test stronger (#9530)

This commit is contained in:
Chen Zhang 2024-10-19 17:05:02 -07:00 committed by GitHub
parent 8e3e7f2713
commit 5b59fe0f08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 29 additions and 9 deletions

View File

@ -851,14 +851,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI): async def test_response_format_json_schema(client: openai.AsyncOpenAI):
prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2): for _ in range(2):
resp = await client.chat.completions.create( resp = await client.chat.completions.create(
model=MODEL_NAME, model=MODEL_NAME,
messages=[{ messages=[{
"role": "role": "user",
"user", "content": prompt
"content": ('what is 1+1? please respond with a JSON object, ' }],
'the format is {"result": 2}') )
content = resp.choices[0].message.content
assert content is not None
with pytest.raises((json.JSONDecodeError, AssertionError)):
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": prompt
}], }],
response_format={ response_format={
"type": "json_schema", "type": "json_schema",

View File

@ -314,9 +314,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
prompt_logprobs = self.top_logprobs prompt_logprobs = self.top_logprobs
guided_json_object = None guided_json_object = None
if (self.response_format is not None if self.response_format is not None:
and self.response_format.type == "json_object"): if self.response_format.type == "json_object":
guided_json_object = True guided_json_object = True
elif self.response_format.type == "json_schema":
json_schema = self.response_format.json_schema
assert json_schema is not None
self.guided_json = json_schema.json_schema
if self.guided_decoding_backend is None:
self.guided_decoding_backend = "lm-format-enforcer"
guided_decoding = GuidedDecodingParams.from_optional( guided_decoding = GuidedDecodingParams.from_optional(
json=self._get_guided_json_from_tool() or self.guided_json, json=self._get_guided_json_from_tool() or self.guided_json,
@ -537,8 +543,8 @@ class CompletionRequest(OpenAIBaseModel):
default=None, default=None,
description= description=
("Similar to chat completion, this parameter specifies the format of " ("Similar to chat completion, this parameter specifies the format of "
"output. Only {'type': 'json_object'} or {'type': 'text' } is " "output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
"supported."), "{'type': 'text' } is supported."),
) )
guided_json: Optional[Union[str, dict, BaseModel]] = Field( guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None, default=None,