[Bugfix] OpenAI entrypoint limits logprobs while ignoring server defined --max-logprobs (#5312)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
maor-ps 2024-06-11 05:30:31 +03:00 committed by GitHub
parent a008629807
commit 351d5e7b82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 9 deletions

View File

@ -264,7 +264,9 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
# vLLM has higher default max_logprobs (20 instead of 5) to support
# both Completion API and Chat Completion API
logprobs=21,
)
...
with pytest.raises(
@ -274,7 +276,9 @@ async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
# vLLM has higher default max_logprobs (20 instead of 5) to support
# both Completion API and Chat Completion API
logprobs=30,
stream=True,
)
async for chunk in stream:

View File

@ -100,7 +100,7 @@ class ModelConfig:
enforce_eager: bool = False,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 5,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,

View File

@ -48,7 +48,7 @@ class EngineArgs:
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_logprobs: int = 5 # OpenAI default value
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None

View File

@ -322,9 +322,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif not 0 <= data["top_logprobs"] <= 20:
elif data["top_logprobs"] < 0:
raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].")
"`top_logprobs` must be a value a positive value.")
return data
@ -478,9 +478,8 @@ class CompletionRequest(OpenAIBaseModel):
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
raise ValueError(("if passed, `logprobs` must be a value",
" in the interval [0, 5]."))
"logprobs"] is not None and not data["logprobs"] >= 0:
raise ValueError("if passed, `logprobs` must be a positive value.")
return data
@model_validator(mode="before")