Supports tokens and arrays of tokens as inputs to the OpenAI completion API (#715)

parent 462ae5220a
commit e06f504a76
vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@

 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version
+from http import HTTPStatus
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union

 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version

 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -115,8 +115,18 @@ async def get_gen_prompt(request) -> str:
     return prompt


-async def check_length(request, prompt):
-    input_ids = tokenizer(prompt).input_ids
+async def check_length(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+    prompt: Optional[str] = None,
+    prompt_ids: Optional[List[int]] = None
+) -> Tuple[List[int], Optional[JSONResponse]]:
+    assert (not (prompt is None and prompt_ids is None)
+            and not (prompt is not None and prompt_ids is not None)
+            ), "Either prompt or prompt_ids should be provided."
+    if prompt_ids is not None:
+        input_ids = prompt_ids
+    else:
+        input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)

     if token_num + request.max_tokens > max_model_len:
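The new signature lets check_length accept exactly one of a text prompt or a pre-tokenized prompt. A minimal sketch of the resulting contract, assuming it is called inside an async handler, with a CompletionRequest built elsewhere and placeholder token ids:

# Text prompt: the server tokenizes it and returns the resulting ids.
token_ids, err = await check_length(request, prompt="Hello, world")

# Pre-tokenized prompt: the ids are used as-is, skipping the tokenizer.
token_ids, err = await check_length(request, prompt_ids=[9906, 11, 1917])

# Supplying both or neither trips the assertion.
token_ids, err = await check_length(request)  # AssertionError

Either way, the returned ids feed the max_model_len check below, so an over-long token-id prompt is rejected just like an over-long text prompt.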
@@ -191,7 +201,7 @@ async def create_chat_completion(raw_request: Request):
             "logit_bias is not currently supported")

     prompt = await get_gen_prompt(request)
-    token_ids, error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret

@@ -376,19 +386,31 @@ async def create_completion(raw_request: Request):

     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, (str, list)):
+            # TODO: handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt

-    token_ids, error_check_ret = await check_length(request, prompt)
+    if use_token_ids:
+        _, error_check_ret = await check_length(request, prompt_ids=prompt)
+    else:
+        token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret

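To see what the new branching accepts, here is a standalone rendering of the dispatch above; a sketch for illustration only, with classify_prompt as a hypothetical helper name, not server code:

def classify_prompt(prompt):
    # Mirrors the branch above: returns (use_token_ids, unwrapped prompt).
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        first_element = prompt[0]
        if isinstance(first_element, int):
            return True, prompt            # List[int]: a single token array
        if len(prompt) > 1:                # batches remain unsupported
            raise ValueError(
                "multiple prompts in a batch is not currently supported")
        # List[str] unwraps to str; List[List[int]] unwraps to List[int]
        return not isinstance(first_element, str), prompt[0]
    return False, prompt                   # plain string

assert classify_prompt("Hi") == (False, "Hi")
assert classify_prompt(["Hi"]) == (False, "Hi")
assert classify_prompt([1, 2, 3]) == (True, [1, 2, 3])
assert classify_prompt([[1, 2, 3]]) == (True, [1, 2, 3])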
@@ -411,8 +433,14 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params, request_id,
-                                       token_ids)
+    if use_token_ids:
+        result_generator = engine.generate(None,
+                                           sampling_params,
+                                           request_id,
+                                           prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id,
+                                           token_ids)

     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
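The branch exists because a token-id prompt has no text to hand the engine: passing None as the prompt together with prompt_token_ids tells AsyncLLMEngine.generate to use the ids directly rather than tokenize, while the text path still forwards the ids that check_length already computed, so nothing is tokenized twice.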
vllm/entrypoints/openai/protocol.py
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):

 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
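A hedged client-side sketch exercising each shape the widened prompt field now accepts; the URL, model name, and token ids below are placeholders, not values from this commit:

import requests

url = "http://localhost:8000/v1/completions"
payloads = [
    "Hello, world",        # str
    ["Hello, world"],      # List[str], single element only for now
    [9906, 11, 1917],      # List[int]: one pre-tokenized prompt
    [[9906, 11, 1917]],    # List[List[int]], single element only for now
]
for prompt in payloads:
    resp = requests.post(url, json={
        "model": "facebook/opt-125m",
        "prompt": prompt,
        "max_tokens": 16,
    })
    print(resp.json())

Token ids are model-specific, so a pre-tokenized prompt only makes sense when produced by the same tokenizer the server loaded.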