diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 8a2573fe2b0e..43eaa5c60df1 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -357,7 +357,11 @@ async def create_completion(raw_request: Request):

     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
-    prompt = request.prompt
+    if isinstance(request.prompt, list):
+        assert len(request.prompt) == 1
+        prompt = request.prompt[0]
+    else:
+        prompt = request.prompt
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 3728241edc03..6c45b507329b 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -73,7 +73,7 @@ class ChatCompletionRequest(BaseModel):

 class CompletionRequest(BaseModel):
     model: str
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
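
For reference, a minimal client-side sketch of the request shape this change accepts. It is not part of the diff; the server URL and model name below are placeholders chosen for illustration.

# Hypothetical client example; assumes a vLLM OpenAI-compatible server
# running at http://localhost:8000 and a placeholder model name.
import requests

payload = {
    "model": "facebook/opt-125m",     # placeholder; use the model the server serves
    "prompt": ["Hello, my name is"],  # list form is now accepted, but only
                                      # with exactly one element: the server
                                      # asserts len(request.prompt) == 1
    "max_tokens": 16,
}
response = requests.post("http://localhost:8000/v1/completions", json=payload)
print(response.json())

A multi-element list still trips the len(request.prompt) == 1 assertion, so this change accepts the list-wrapped prompts that OpenAI client libraries send without adding true batched completion support.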