Check the max prompt length for the OpenAI completions API (#472)

Author: Nicolas Basile, 2023-08-08 17:43:49 -07:00 (committed by GitHub)
Parent: 735ecfff61
Commit: 66c54aa9c3


@@ -120,7 +120,7 @@ async def check_length(request, prompt):
     token_num = len(input_ids)

     if token_num + request.max_tokens > max_model_len:
-        return create_error_response(
+        return input_ids, create_error_response(
             HTTPStatus.BAD_REQUEST,
             f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
@@ -129,7 +129,7 @@ async def check_length(request, prompt):
             f"Please reduce the length of the messages or completion.",
         )
     else:
-        return None
+        return input_ids, None


 @app.get("/v1/models")
@@ -191,7 +191,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")

     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret

@@ -215,7 +215,8 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)

     async def abort_request() -> None:
         await engine.abort(request_id)
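
The second hunk threads the pre-computed ids into the engine so the chat prompt is not tokenized again. Assuming the engine's generate() takes the optional prompt token ids as its fourth parameter (in vLLM's AsyncLLMEngine that parameter is prompt_token_ids; the name is an assumption here, not something shown in this diff), the new call is equivalent to:

# Same call as in the hunk above, written with an explicit keyword argument:
result_generator = engine.generate(prompt, sampling_params, request_id,
                                   prompt_token_ids=token_ids)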
@@ -386,6 +387,11 @@ async def create_completion(raw_request: Request):
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
+    token_ids, error_check_ret = await check_length(request, prompt)
+    if error_check_ret is not None:
+        return error_check_ret
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +411,8 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
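
With this last pair of hunks, /v1/completions now performs the same length check that /v1/chat/completions already did, so a request whose prompt plus max_tokens exceeds the model's context length is rejected with a clean 400 instead of failing inside the engine. A hypothetical client call illustrating the behaviour (host, port, and model name are placeholders, not part of this commit):

import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",   # placeholder server address
    json={
        "model": "facebook/opt-125m",          # placeholder model name
        "prompt": "a prompt long enough to exceed the context window ...",
        "max_tokens": 1_000_000,               # deliberately too large
    },
)
print(resp.status_code)   # 400 (BAD_REQUEST) rather than an engine-side error
print(resp.json())        # "This model's maximum context length is ... tokens."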