diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py
index bff248711783..3721b047e43d 100644
--- a/tests/entrypoints/test_openai_server.py
+++ b/tests/entrypoints/test_openai_server.py
@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 1
+    assert len(choice.logprobs.top_logprobs[0]) == 1


 @pytest.mark.asyncio
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
     assert choice.logprobs.top_logprobs is not None
-    assert len(choice.logprobs.top_logprobs[0]) <= 6
+    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6


 @pytest.mark.asyncio
@@ -1217,8 +1217,9 @@ number: "1" | "2"
     "model_name",
     [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
 )
+@pytest.mark.parametrize("logprobs_arg", [1, 0])
 async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
-                                       model_name: str):
+                                       model_name: str, logprobs_arg: int):
     tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
     # test using text and token IDs
     for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                                                       max_tokens=5,
                                                       temperature=0.0,
                                                       echo=True,
-                                                      logprobs=1)
+                                                      logprobs=logprobs_arg)

         prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                              list) else prompt
@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                 and logprobs.token_logprobs[0] is None)
         assert (len(logprobs.top_logprobs) > 5
                 and logprobs.top_logprobs[0] is None)
+        for top_logprobs in logprobs.top_logprobs[1:]:
+            assert max(logprobs_arg,
+                       1) <= len(top_logprobs) <= logprobs_arg + 1
         assert len(logprobs.tokens) > 5


diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 2fb122edaf98..572878b5527d 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -312,7 +312,7 @@ class OpenAIServingCompletion(OpenAIServing):
                 elif request.echo and request.max_tokens > 0:
                     token_ids = prompt_token_ids + output.token_ids
                     top_logprobs = (prompt_logprobs + output.logprobs
-                                    if request.logprobs else None)
+                                    if request.logprobs is not None else None)
                     output_text = prompt_text + output.text
                 else:
                     token_ids = output.token_ids
diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py
index 9969c45963e9..0b3b41e69d6b 100644
--- a/vllm/model_executor/sampling_metadata.py
+++ b/vllm/model_executor/sampling_metadata.py
@@ -233,7 +233,7 @@ def _prepare_seq_groups(
             logits = hidden_states[selected_token_indices]
         """

-        if sampling_params.prompt_logprobs:
+        if sampling_params.prompt_logprobs is not None:
             selected_token_indices.extend(
                 range(model_output_idx, model_output_idx + prompt_logprob_len))
         model_output_idx += prompt_logprob_len
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 63ec22d79694..67c03ad60008 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -427,7 +427,7 @@ class ModelRunner:
                 [lora_id] *
                 (query_len if seq_group_metadata.sampling_params
                  and seq_group_metadata.sampling_params.prompt_logprobs
-                 else 1))
+                 is not None else 1))

             mm_data = seq_group_metadata.multi_modal_data
             if mm_data is not None: