Mirror of https://git.datalinker.icu/vllm-project/vllm.git
Check the max prompt length for the OpenAI completions API (#472)

parent 735ecfff61
commit 66c54aa9c3
@@ -120,7 +120,7 @@ async def check_length(request, prompt):
     token_num = len(input_ids)
 
     if token_num + request.max_tokens > max_model_len:
-        return create_error_response(
+        return input_ids, create_error_response(
             HTTPStatus.BAD_REQUEST,
             f"This model's maximum context length is {max_model_len} tokens. "
             f"However, you requested {request.max_tokens + token_num} tokens "
@@ -129,7 +129,7 @@ async def check_length(request, prompt):
             f"Please reduce the length of the messages or completion.",
         )
     else:
-        return None
+        return input_ids, None
 
 
 @app.get("/v1/models")
@@ -191,7 +191,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -215,7 +215,8 @@ async def create_chat_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)
 
     async def abort_request() -> None:
         await engine.abort(request_id)
@@ -386,6 +387,11 @@ async def create_completion(raw_request: Request):
         prompt = request.prompt[0]
     else:
         prompt = request.prompt
+
+    token_ids, error_check_ret = await check_length(request, prompt)
+    if error_check_ret is not None:
+        return error_check_ret
+
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
@@ -405,7 +411,8 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id,
+                                       token_ids)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
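
For context, the change makes check_length return the tokenized prompt together with any error, so both the chat and completions handlers can reuse the token IDs (and pass them on to engine.generate) instead of tokenizing the prompt a second time. Below is a minimal, self-contained sketch of that calling convention only; CompletionRequest, fake_tokenize, max_model_len, and the stubbed create_error_response are illustrative stand-ins, not the upstream server code.

import asyncio
from dataclasses import dataclass
from http import HTTPStatus

max_model_len = 2048  # assumed context window for this sketch


@dataclass
class CompletionRequest:
    # Stand-in for the server's request model.
    prompt: str
    max_tokens: int


def fake_tokenize(prompt: str) -> list:
    # Stand-in for tokenizer(prompt).input_ids in the real server.
    return list(range(len(prompt.split())))


def create_error_response(status: HTTPStatus, message: str) -> dict:
    # Stand-in for the JSON error response the real server builds.
    return {"error": {"code": status.value, "message": message}}


async def check_length(request, prompt):
    input_ids = fake_tokenize(prompt)
    token_num = len(input_ids)
    if token_num + request.max_tokens > max_model_len:
        # Return the token IDs alongside the error so callers can always
        # unpack `token_ids, error_check_ret = await check_length(...)`.
        return input_ids, create_error_response(
            HTTPStatus.BAD_REQUEST,
            f"This model's maximum context length is {max_model_len} tokens. "
            f"However, you requested {request.max_tokens + token_num} tokens.")
    return input_ids, None


async def main():
    request = CompletionRequest(prompt="hello world", max_tokens=16)
    token_ids, error_check_ret = await check_length(request, request.prompt)
    if error_check_ret is not None:
        print(error_check_ret)
        return
    # The caller can now hand the already-computed token IDs to the engine,
    # e.g. engine.generate(prompt, sampling_params, request_id, token_ids).
    print(f"prompt is {len(token_ids)} tokens; OK to generate")


asyncio.run(main())

Returning input_ids even on the error path keeps the tuple-unpacking at every call site uniform, which is the shape the handlers in the diff above rely on.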