From e06f504a761aba85ba34472d6c2544b626c311a8 Mon Sep 17 00:00:00 2001
From: WanMok <16273544+wanmok@users.noreply.github.com>
Date: Fri, 11 Aug 2023 12:14:34 -0700
Subject: [PATCH] Supports tokens and arrays of tokens as inputs to the OpenAI
 completion API (#715)

---
 vllm/entrypoints/openai/api_server.py | 58 ++++++++++++++++++++-------
 vllm/entrypoints/openai/protocol.py   |  3 +-
 2 files changed, 45 insertions(+), 16 deletions(-)

diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 8acea787c1b7d..97d097e60f315 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -3,18 +3,18 @@
 
 import argparse
 import asyncio
-from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
-from packaging import version
+from http import HTTPStatus
+from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
 
 import fastapi
+import uvicorn
 from fastapi import BackgroundTasks, Request
 from fastapi.exceptions import RequestValidationError
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
-import uvicorn
+from packaging import version
 
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -115,8 +115,18 @@ async def get_gen_prompt(request) -> str:
     return prompt
 
 
-async def check_length(request, prompt):
-    input_ids = tokenizer(prompt).input_ids
+async def check_length(
+    request: Union[ChatCompletionRequest, CompletionRequest],
+    prompt: Optional[str] = None,
+    prompt_ids: Optional[List[int]] = None
+) -> Tuple[List[int], Optional[JSONResponse]]:
+    assert (not (prompt is None and prompt_ids is None)
+            and not (prompt is not None and prompt_ids is not None)
+            ), "Either prompt or prompt_ids should be provided."
+    if prompt_ids is not None:
+        input_ids = prompt_ids
+    else:
+        input_ids = tokenizer(prompt).input_ids
     token_num = len(input_ids)
 
     if token_num + request.max_tokens > max_model_len:
@@ -191,7 +201,7 @@ async def create_chat_completion(raw_request: Request):
                                      "logit_bias is not currently supported")
 
     prompt = await get_gen_prompt(request)
-    token_ids, error_check_ret = await check_length(request, prompt)
+    token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -376,19 +386,31 @@ async def create_completion(raw_request: Request):
 
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
+
+    use_token_ids = False
     if isinstance(request.prompt, list):
         if len(request.prompt) == 0:
             return create_error_response(HTTPStatus.BAD_REQUEST,
                                          "please provide at least one prompt")
-        if len(request.prompt) > 1:
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "multiple prompts in a batch is not currently supported")
-        prompt = request.prompt[0]
+        first_element = request.prompt[0]
+        if isinstance(first_element, int):
+            use_token_ids = True
+            prompt = request.prompt
+        elif isinstance(first_element, (str, list)):
+            # TODO: handles multiple prompt case in list[list[int]]
+            if len(request.prompt) > 1:
+                return create_error_response(
+                    HTTPStatus.BAD_REQUEST,
+                    "multiple prompts in a batch is not currently supported")
+            use_token_ids = not isinstance(first_element, str)
+            prompt = request.prompt[0]
     else:
         prompt = request.prompt
 
-    token_ids, error_check_ret = await check_length(request, prompt)
+    if use_token_ids:
+        _, error_check_ret = await check_length(request, prompt_ids=prompt)
+    else:
+        token_ids, error_check_ret = await check_length(request, prompt=prompt)
     if error_check_ret is not None:
         return error_check_ret
 
@@ -411,8 +433,14 @@ async def create_completion(raw_request: Request):
     except ValueError as e:
         return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
 
-    result_generator = engine.generate(prompt, sampling_params, request_id,
-                                       token_ids)
+    if use_token_ids:
+        result_generator = engine.generate(None,
+                                           sampling_params,
+                                           request_id,
+                                           prompt_token_ids=prompt)
+    else:
+        result_generator = engine.generate(prompt, sampling_params, request_id,
+                                           token_ids)
 
     # Similar to the OpenAI API, when n != best_of, we do not stream the
     # results. In addition, we do not stream the results when use beam search.
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index c63e7a2964fc8..701f704234ad6 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -74,7 +74,8 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: Union[str, List[str]]
+    # a string, array of strings, array of tokens, or array of token arrays
+    prompt: Union[List[int], List[List[int]], str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
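
As a quick illustration of what the widened prompt field accepts, below is a minimal client-side sketch, not part of the patch itself. It tokenizes a prompt locally with Hugging Face transformers and posts the resulting token IDs straight to the /v1/completions endpoint, so the server-side tokenizer(prompt) call inside check_length is skipped. The host, port, and model name are placeholders for whatever the vLLM server was actually launched with.

    # Client-side sketch for the new token-ID prompt support (illustrative only).
    # Assumes an OpenAI-compatible vLLM server is already running, e.g.:
    #   python -m vllm.entrypoints.openai.api_server --model facebook/opt-125m
    # The URL and model name below are placeholders.
    import requests
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

    # Tokenize locally; input_ids is a plain list of ints for a single string.
    token_ids = tokenizer("San Francisco is a").input_ids

    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "facebook/opt-125m",
            # With this patch, prompt may be a string, a list of strings,
            # a list of token IDs, or a list of token-ID lists.
            "prompt": token_ids,
            "max_tokens": 16,
            "temperature": 0.0,
        },
    )
    print(response.json()["choices"][0]["text"])

A single token-ID list wrapped in an outer list also works, but a batch of several token-ID lists is still rejected, matching the TODO left in create_completion.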