From d81abefd2ee8e1f4b46b3660ebdaf7b8e19c573a Mon Sep 17 00:00:00 2001 From: Tyler Rockwood Date: Sat, 24 Aug 2024 01:07:24 -0500 Subject: [PATCH] [Frontend] add json_schema support from OpenAI protocol (#7654) --- tests/entrypoints/openai/test_chat.py | 33 +++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 14 ++++++-- .../lm_format_enforcer_decoding.py | 7 ++++ .../guided_decoding/outlines_decoding.py | 7 ++++ 4 files changed, 59 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index afcb0f44befc5..ce5bf3d5d7ba0 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -837,6 +837,39 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +@pytest.mark.asyncio +async def test_response_format_json_schema(client: openai.AsyncOpenAI): + for _ in range(2): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": + "user", + "content": ('what is 1+1? please respond with a JSON object, ' + 'the format is {"result": 2}') + }], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "foo_test", + "schema": { + "type": "object", + "properties": { + "result": { + "type": "integer" + }, + }, + }, + } + }) + + content = resp.choices[0].message.content + assert content is not None + + loaded = json.loads(content) + assert loaded == {"result": 2}, loaded + + @pytest.mark.asyncio async def test_extra_fields(client: openai.AsyncOpenAI): with pytest.raises(BadRequestError) as exc_info: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c46f5cf8ce663..0954b81595ef5 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -85,9 +85,19 @@ class UsageInfo(OpenAIBaseModel): completion_tokens: Optional[int] = 0 +class JsonSchemaResponseFormat(OpenAIBaseModel): + name: str + description: Optional[str] = None + # schema is the field in openai but that causes conflicts with pydantic so + # instead use json_schema with an alias + json_schema: Optional[Dict[str, Any]] = Field(default=None, alias='schema') + strict: Optional[bool] = None + + class ResponseFormat(OpenAIBaseModel): - # type must be "json_object" or "text" - type: Literal["text", "json_object"] + # type must be "json_schema", "json_object" or "text" + type: Literal["text", "json_object", "json_schema"] + json_schema: Optional[JsonSchemaResponseFormat] = None class StreamOptions(OpenAIBaseModel): diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index b2188c9cbc2bb..8de811a6fbc41 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -49,6 +49,13 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( and request.response_format.type == "json_object"): character_level_parser = JsonSchemaParser( None) # None means any json object + elif (request.response_format is not None + and request.response_format.type == "json_schema" + and request.response_format.json_schema is not None + and request.response_format.json_schema.json_schema is not None): + schema = _normalize_json_schema_object( + request.response_format.json_schema.json_schema) + character_level_parser = JsonSchemaParser(schema) else: return None diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bc62224dabecf..bfc658ef7d26b 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -127,6 +127,13 @@ def _get_guide_and_mode( and request.response_format is not None and request.response_format.type == "json_object"): return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR + elif (not isinstance(request, GuidedDecodingRequest) + and request.response_format is not None + and request.response_format.type == "json_schema" + and request.response_format.json_schema is not None + and request.response_format.json_schema.json_schema is not None): + json = json_dumps(request.response_format.json_schema.json_schema) + return json, GuidedDecodingMode.JSON else: return None, None