diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py
index de63f4ed218b6..0f459dd3d8574 100644
--- a/tests/entrypoints/openai/test_chat_echo.py
+++ b/tests/entrypoints/openai/test_chat_echo.py
@@ -22,6 +22,8 @@ def server():
         "--enforce-eager",
         "--max-model-len",
         "4080",
+        "--max-logprobs",  # test prompt_logprobs equal to -1
+        "151936"
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -77,3 +79,23 @@ async def test_chat_session_with_echo_and_continue_final_message(
         else:
             assert message.content is not None and saying not in message.content
         assert message.role == "assistant"
+
+
+@pytest.mark.asyncio
+async def test_prompt_logprobs(client: openai.AsyncOpenAI):
+    messages = [{
+        "role": "system",
+        "content": "You are a helpful assistant."
+    }, {
+        "role": "user",
+        "content": "Beijing is the capital of which country?"
+    }]
+
+    completion = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=messages,
+        extra_body={"prompt_logprobs": -1},
+    )
+
+    assert completion.prompt_logprobs is not None
+    assert len(completion.prompt_logprobs) > 0
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 8ecb1a8239c35..6b4c3f531dbce 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -822,13 +822,17 @@ class ChatCompletionRequest(OpenAIBaseModel):
     @classmethod
     def check_logprobs(cls, data):
         if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
-            if data.get("stream") and prompt_logprobs > 0:
+            if data.get("stream") and (prompt_logprobs > 0
+                                       or prompt_logprobs == -1):
                 raise ValueError(
                     "`prompt_logprobs` are not available when `stream=True`.")
 
-            if prompt_logprobs < 0:
-                raise ValueError("`prompt_logprobs` must be a positive value.")
-
+            if prompt_logprobs < 0 and prompt_logprobs != -1:
+                raise ValueError(
+                    "`prompt_logprobs` must be a positive value or -1.")
+            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
+                raise ValueError("`prompt_logprobs=-1` is only supported with "
+                                 "vLLM engine V1.")
         if (top_logprobs := data.get("top_logprobs")) is not None:
             if top_logprobs < 0:
                 raise ValueError("`top_logprobs` must be a positive value.")
@@ -1246,13 +1250,17 @@ class CompletionRequest(OpenAIBaseModel):
     @classmethod
     def check_logprobs(cls, data):
         if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
-            if data.get("stream") and prompt_logprobs > 0:
+            if data.get("stream") and (prompt_logprobs > 0
+                                       or prompt_logprobs == -1):
                 raise ValueError(
                     "`prompt_logprobs` are not available when `stream=True`.")
 
-            if prompt_logprobs < 0:
-                raise ValueError("`prompt_logprobs` must be a positive value.")
-
+            if prompt_logprobs < 0 and prompt_logprobs != -1:
+                raise ValueError(
+                    "`prompt_logprobs` must be a positive value or -1.")
+            if prompt_logprobs == -1 and not envs.VLLM_USE_V1:
+                raise ValueError("`prompt_logprobs=-1` is only supported with "
+                                 "vLLM engine V1.")
         if (logprobs := data.get("logprobs")) is not None and logprobs < 0:
             raise ValueError("`logprobs` must be a positive value.")
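
For reference, a minimal client-side sketch of the new behavior (not part of the patch): it mirrors the added test, querying a running vLLM OpenAI-compatible server with `prompt_logprobs=-1`. The base URL, API key, and model name below are placeholder assumptions; judging by the raised `--max-logprobs` in the test fixture, `-1` requests prompt logprobs over the full vocabulary, and per the validator it is accepted only on the V1 engine and never together with `stream=True`.

```python
# Usage sketch (assumption: a vLLM V1 server is running locally with a
# sufficiently large --max-logprobs; model name is a placeholder).
import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # placeholder server address
    api_key="EMPTY",                      # vLLM ignores the key by default
)

completion = client.chat.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",   # placeholder; use the served model
    messages=[{
        "role": "user",
        "content": "Beijing is the capital of which country?"
    }],
    # vLLM extension field: -1 instead of a fixed top-k (V1 engine only,
    # rejected when stream=True by the validator in this diff).
    extra_body={"prompt_logprobs": -1},
)

# `prompt_logprobs` is a vLLM-specific field on the response object,
# with one entry per prompt token.
assert completion.prompt_logprobs
print(completion.prompt_logprobs[:2])
```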
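The new validation rules in `protocol.py` can also be exercised directly. A sketch under the assumption that `vllm` is importable and that `check_logprobs` runs as a pydantic `mode="before"` validator (so the `ValueError`s surface as `ValidationError`); the model and message values are placeholders:

```python
# Both constructions below should be rejected by check_logprobs.
import pydantic

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

base = dict(model="placeholder-model",
            messages=[{"role": "user", "content": "hi"}])

for bad in (dict(prompt_logprobs=-1, stream=True),  # no prompt logprobs when streaming
            dict(prompt_logprobs=-2)):              # only >= 0 or exactly -1 allowed
    try:
        ChatCompletionRequest(**base, **bad)
    except pydantic.ValidationError as err:
        print(err.errors()[0]["msg"])
```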