diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index e91e7c1a07b2..d260396e47c4 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -130,6 +130,8 @@ async def check_length(
     input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
+    if request.max_tokens is None:
+        request.max_tokens = max_model_len - token_num
     if token_num + request.max_tokens > max_model_len:
         return input_ids, create_error_response(
             HTTPStatus.BAD_REQUEST,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index b45e4146e29f..473400a7faf9 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -58,7 +58,7 @@ class ChatCompletionRequest(BaseModel):
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    max_tokens: Optional[int] = 16
+    max_tokens: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
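
For context, here is a minimal standalone sketch of the defaulting behavior this change introduces: when a chat completion request omits max_tokens, it now falls back to whatever context space the prompt leaves, instead of the old hard-coded default of 16. The function name resolve_max_tokens and its signature below are illustrative only, not part of the vLLM API.

from typing import Optional


def resolve_max_tokens(prompt_token_num: int,
                       max_model_len: int,
                       requested_max_tokens: Optional[int] = None) -> int:
    """Default max_tokens to the remaining room in the model context."""
    if requested_max_tokens is None:
        # Same idea as the patched check_length(): fill the rest of the context.
        requested_max_tokens = max_model_len - prompt_token_num
    if prompt_token_num + requested_max_tokens > max_model_len:
        raise ValueError(
            f"Prompt ({prompt_token_num} tokens) plus max_tokens "
            f"({requested_max_tokens}) exceeds max_model_len ({max_model_len}).")
    return requested_max_tokens


# Example: a 100-token prompt against a 4096-token context with no
# max_tokens in the request may now generate up to 3996 tokens.
print(resolve_max_tokens(prompt_token_num=100, max_model_len=4096))  # 3996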