[Bugfix] Support prompt_logprobs==0 (#5217)

commit 06b2550cbb (parent f775a07e30)
Author: Toshiki Kataoka
Date:   2024-06-04 09:59:30 +09:00, committed by GitHub
4 changed files with 11 additions and 7 deletions

Changed file 1 of 4 (OpenAI completions API tests):

@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 1
+    assert len(choice.logprobs.top_logprobs[0]) == 1


 @pytest.mark.asyncio
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 6
+    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6


 @pytest.mark.asyncio
@@ -1217,8 +1217,9 @@ number: "1" | "2"
     "model_name",
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
+@pytest.mark.parametrize("logprobs_arg", [1, 0])
 async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
-                                       model_name: str):
+                                       model_name: str, logprobs_arg: int):
     tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
     # test using text and token IDs
     for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                                                      max_tokens=5,
                                                      temperature=0.0,
                                                      echo=True,
-                                                     logprobs=1)
+                                                     logprobs=logprobs_arg)
         prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                              list) else prompt
@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                 and logprobs.token_logprobs[0] is None)
         assert (len(logprobs.top_logprobs) > 5
                 and logprobs.top_logprobs[0] is None)
+        for top_logprobs in logprobs.top_logprobs[1:]:
+            assert max(logprobs_arg,
+                       1) <= len(top_logprobs) <= logprobs_arg + 1
         assert len(logprobs.tokens) > 5
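
Taken together, the new assertions encode one counting rule: for a request with logprobs=N, each returned position carries the top-N candidates plus the sampled token itself, which may or may not already be among them, so every top_logprobs entry holds between max(N, 1) and N + 1 items. That gives exactly 1 for the N=0 test above and 5 <= len <= 6 for the N=5 case in test_some_logprobs. A minimal sketch of the bound in plain Python (entry_count is a hypothetical helper, not vLLM code):

    # Entries returned for one position: the top-n candidates, plus the
    # sampled token unless it is already among them (or n == 0).
    def entry_count(n: int, sampled_in_top_n: bool) -> int:
        return n if (sampled_in_top_n and n > 0) else n + 1

    # The asserted bound holds in every case, including n == 0.
    for n in (0, 1, 5):
        for dup in (False, True):
            assert max(n, 1) <= entry_count(n, dup) <= n + 1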

Changed file 2 of 4 (OpenAI completion serving):

@@ -312,7 +312,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 elif request.echo and request.max_tokens > 0:
                     token_ids = prompt_token_ids + output.token_ids
                     top_logprobs = (prompt_logprobs + output.logprobs
-                                    if request.logprobs else None)
+                                    if request.logprobs is not None else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
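
The serving-layer half of the bug is an ordinary Python truthiness slip: logprobs=0 is a valid request ("return only the chosen token's logprob"), but 0 is falsy, so the old check treated it like an omitted parameter and dropped the echoed prompt logprobs. A short illustration of the pitfall:

    request_logprobs = 0                   # valid: chosen-token logprobs only
    assert not request_logprobs            # old check: 0 looks "absent"
    assert request_logprobs is not None    # new check: 0 counts as "present"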

Changed file 3 of 4 (sampling metadata preparation):

@@ -233,7 +233,7 @@ def _prepare_seq_groups(
         logits = hidden_states[selected_token_indices]
         """
-        if sampling_params.prompt_logprobs:
+        if sampling_params.prompt_logprobs is not None:
             selected_token_indices.extend(
                 range(model_output_idx, model_output_idx + prompt_logprob_len))
         model_output_idx += prompt_logprob_len
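
The same truthiness fix is needed when preparing sampling metadata: with SamplingParams(prompt_logprobs=0), the old condition never extended selected_token_indices, so no prompt positions were selected for logprob computation at all. A simplified sketch of the index bookkeeping (hypothetical lengths, not the real _prepare_seq_groups):

    prompt_logprobs = 0                    # valid setting, but falsy
    model_output_idx, prompt_logprob_len = 0, 4
    selected_token_indices: list[int] = []
    if prompt_logprobs is not None:        # a truthiness check would skip this
        selected_token_indices.extend(
            range(model_output_idx, model_output_idx + prompt_logprob_len))
    model_output_idx += prompt_logprob_len
    assert selected_token_indices == [0, 1, 2, 3]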

Changed file 4 of 4 (model runner):

@@ -427,7 +427,7 @@ class ModelRunner:
                     [lora_id] *
                     (query_len if seq_group_metadata.sampling_params
                      and seq_group_metadata.sampling_params.prompt_logprobs
-                     else 1))
+                     is not None else 1))
                 mm_data = seq_group_metadata.multi_modal_data
                 if mm_data is not None:
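
The ModelRunner change closes the last gap: whenever prompt_logprobs is set (including 0), logits are gathered for every prompt position, so the LoRA prompt mapping must cover query_len positions rather than just the last one. With all four call sites agreeing on the "is not None" test, prompt_logprobs=0 works end to end. A usage sketch against the offline API (model name and prompt are illustrative):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    params = SamplingParams(max_tokens=5, prompt_logprobs=0)
    output = llm.generate(["Hello, my name is"], params)[0]
    # With the fix, the engine returns an entry per prompt token (each
    # token's own logprob) instead of treating 0 like "no prompt logprobs".
    assert output.prompt_logprobs is not None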