mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-04 04:47:53 +08:00
[Bugfix] Pass json-schema to GuidedDecodingParams and make test stronger (#9530)
This commit is contained in:
parent
8e3e7f2713
commit
5b59fe0f08
@ -851,14 +851,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
|
||||||
|
prompt = 'what is 1+1? The format is "result": 2'
|
||||||
|
# Check that this prompt cannot lead to a valid JSON without json_schema
|
||||||
for _ in range(2):
|
for _ in range(2):
|
||||||
resp = await client.chat.completions.create(
|
resp = await client.chat.completions.create(
|
||||||
model=MODEL_NAME,
|
model=MODEL_NAME,
|
||||||
messages=[{
|
messages=[{
|
||||||
"role":
|
"role": "user",
|
||||||
"user",
|
"content": prompt
|
||||||
"content": ('what is 1+1? please respond with a JSON object, '
|
}],
|
||||||
'the format is {"result": 2}')
|
)
|
||||||
|
content = resp.choices[0].message.content
|
||||||
|
assert content is not None
|
||||||
|
with pytest.raises((json.JSONDecodeError, AssertionError)):
|
||||||
|
loaded = json.loads(content)
|
||||||
|
assert loaded == {"result": 2}, loaded
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
resp = await client.chat.completions.create(
|
||||||
|
model=MODEL_NAME,
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": prompt
|
||||||
}],
|
}],
|
||||||
response_format={
|
response_format={
|
||||||
"type": "json_schema",
|
"type": "json_schema",
|
||||||
|
|||||||
@ -314,9 +314,15 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
prompt_logprobs = self.top_logprobs
|
prompt_logprobs = self.top_logprobs
|
||||||
|
|
||||||
guided_json_object = None
|
guided_json_object = None
|
||||||
if (self.response_format is not None
|
if self.response_format is not None:
|
||||||
and self.response_format.type == "json_object"):
|
if self.response_format.type == "json_object":
|
||||||
guided_json_object = True
|
guided_json_object = True
|
||||||
|
elif self.response_format.type == "json_schema":
|
||||||
|
json_schema = self.response_format.json_schema
|
||||||
|
assert json_schema is not None
|
||||||
|
self.guided_json = json_schema.json_schema
|
||||||
|
if self.guided_decoding_backend is None:
|
||||||
|
self.guided_decoding_backend = "lm-format-enforcer"
|
||||||
|
|
||||||
guided_decoding = GuidedDecodingParams.from_optional(
|
guided_decoding = GuidedDecodingParams.from_optional(
|
||||||
json=self._get_guided_json_from_tool() or self.guided_json,
|
json=self._get_guided_json_from_tool() or self.guided_json,
|
||||||
@ -537,8 +543,8 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description=
|
description=
|
||||||
("Similar to chat completion, this parameter specifies the format of "
|
("Similar to chat completion, this parameter specifies the format of "
|
||||||
"output. Only {'type': 'json_object'} or {'type': 'text' } is "
|
"output. Only {'type': 'json_object'}, {'type': 'json_schema'} or "
|
||||||
"supported."),
|
"{'type': 'text' } is supported."),
|
||||||
)
|
)
|
||||||
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user