[Bugfix] Support logprobs when using guided_json and other constrained decoding fields (#4149)

This commit is contained in:
James Whedbee 2024-04-18 16:12:55 -05:00 committed by GitHub
parent 705578ae14
commit e1bb2fd52d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 1 deletions

View File

@@ -723,6 +723,36 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
                                           guided_decoding_backend: str):
    """Guided-choice decoding must still yield JSON-serializable logprobs.

    Regression test: constrained decoding used to emit float("-inf")
    logprobs, which break JSON serialization; the server now clamps
    them to -9999.0 (the minimum logprob OpenAI itself returns).
    """
    system_msg = {
        "role": "system",
        "content": "you are a helpful assistant"
    }
    user_msg = {
        "role":
        "user",
        "content":
        "The best language for type-safe systems programming is "
    }
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=[system_msg, user_msg],
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
        extra_body=dict(guided_choice=TEST_CHOICE,
                        guided_decoding_backend=guided_decoding_backend))

    # Every returned logprob must be a finite float no smaller than the
    # clamp value; -9999.0 is the minimum logprob returned by OpenAI.
    for token_dict in chat_completion.choices[0].logprobs.top_logprobs:
        for logprob in token_dict.values():
            assert isinstance(logprob, float)
            assert logprob >= -9999.0
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,

View File

@@ -116,7 +116,9 @@ class OpenAIServing:
if num_output_top_logprobs:
logprobs.top_logprobs.append({
p.decoded_token: p.logprob
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
p.decoded_token: max(p.logprob, -9999.0)
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)