add /v1/completions stream support

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-01-07 22:33:35 +08:00
parent 905424ed65
commit bfde1688e7
4 changed files with 185 additions and 12 deletions

View File

@ -50,22 +50,22 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
return sockets
# select a scoket and execute task
async def execute_task_async(route: str, headers: Headers, request: dict, sockets: list):
async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
sock = await sockets.get()
try:
requestBody = json.dumps(request)
headersJson = json.dumps(dict(headers))
headersJson = json.dumps(headers)
logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
logger.info(f"Sent end")
while True:
logger.info(f"Waiting for reply")
reply = await sock.recv_multipart()
logger.info(f"Received result: {reply}")
yield f"data: {reply[0].decode()}\n\n"
if "finish_reason" in reply[0].decode() and "stop" in reply[0].decode():
[contentType, reply] = await sock.recv_multipart()
logger.info(f"Received result: {contentType}, {reply}")
reply = reply.decode()
yield f"{reply}"
if "[DONE]" in reply:
logger.info(f"Received stop signal, return socket")
yield "data: [DONE]\n\n"
break
finally:
await sockets.put(sock)
@ -73,16 +73,20 @@ async def execute_task_async(route: str, headers: Headers, request: dict, socket
@app.post('/v1/connect/completions')
async def chat_completions(request: Request):
try:
# Add the X-Request-Id header to the raw headers list
x_request_id = str(uuid.uuid4())
header = dict(request.headers)
if header.get("X-Request-Id") is None:
logger.info(f"add X-Request-Id: {x_request_id}")
header["X-Request-Id"] = x_request_id
original_request_data = await request.json()
header = request.headers
logger.info(f"Received request: {original_request_data}")
logger.info(f"Received request: {original_request_data} header: {header}")
prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1
route = "/v1/completions"
# finish prefill
async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
logger.info(f"{x}")
continue
# return decode

View File

@ -6,9 +6,13 @@ import socket
from http import HTTPStatus
from typing import Any, Optional
import zmq
import zmq.asyncio
import uvicorn
from fastapi import FastAPI, Request, Response
from vllm.entrypoints.openai.connect_worker import worker_routine
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
@ -72,8 +76,39 @@ async def serve_http(app: FastAPI,
"port %s is used by process %s launched with command:\n%s",
port, process, " ".join(process.cmdline()))
logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown()
return server.shutdown()
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
"""Server routine"""
logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
url_worker = "inproc://workers"
url_client = f"tcp://0.0.0.0:{zmq_server_port}"
# Prepare our context and sockets
context = zmq.asyncio.Context()
# Socket to talk to clients
clients = context.socket(zmq.ROUTER)
clients.bind(url_client)
logger.info(f"ZMQ Server ROUTER started at {url_client}")
# Socket to talk to workers
workers = context.socket(zmq.DEALER)
workers.bind(url_worker)
logger.info(f"ZMQ Worker DEALER started at {url_worker}")
tasks = [asyncio.create_task(worker_routine(url_worker, app, context, i)) for i in range(5)]
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
try:
await asyncio.gather(*tasks, proxy_task)
except KeyboardInterrupt:
print("ZMQ Server interrupted")
except zmq.ZMQError as e:
print("ZMQError:", e)
finally:
# We never get here but clean up anyhow
clients.close()
workers.close()
context.term()
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""

View File

@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.launcher import serve_http, serve_zmq
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)
@ -1029,6 +1029,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
"s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1])
zmq_server_port = args.zmq_server_port
if zmq_server_port is not None:
logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port)
asyncio.create_task(serve_zmq(args, zmq_server_port, app))
shutdown_task = await serve_http(
app,
sock=sock,

View File

@ -0,0 +1,129 @@
import json
from typing import Optional
import zmq
import zmq.asyncio
import tempfile
import uuid
import httpx
import json
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionRequest,
CompletionResponse,
ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.logger import init_logger
import traceback
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.connect_worker')
def base(app: FastAPI) -> OpenAIServing:
# Reuse the existing instance
return tokenization(app)
def models(app: FastAPI) -> OpenAIServingModels:
return app.state.openai_serving_models
def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
return app.state.openai_serving_chat
def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
return app.state.openai_serving_completion
def tokenization(app: FastAPI) -> OpenAIServingTokenization:
return app.state.openai_serving_tokenization
def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
headers_dict = json.loads(bytes_data.decode())
return httpx.Headers(headers_dict)
async def worker_routine(worker_url: str, app: FastAPI,
context: zmq.asyncio.Context = None, i: int = 0):
"""Worker routine"""
try:
# Socket to talk to dispatcher
socket = context.socket(zmq.DEALER)
worker_identity = f"worker-{i}-{uuid.uuid4()}"
socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
socket.connect(worker_url)
logger.info(f"{worker_identity} started at {worker_url}")
while True:
identity, url, header, body = await socket.recv_multipart()
logger.info(f"worker-{i} Received request identity: [{identity} ]")
url = url.decode()
logger.info(f"worker-{i} Received request url: [{url} ]")
header = bytes_to_headers(header)
logger.info(f"worker-{i} Received request headers: [{header} ]")
body = json.loads(body.decode())
logger.info(f"worker-{i} Received request body: [{body} ]")
logger.info(f"worker-{i} Calling OpenAI API")
completionRequest = CompletionRequest(**body)
createRequest = create_request(url, "POST", body, header)
generator = await create_completion(app, completionRequest, createRequest)
logger.info(f"worker-{i} Received response: [{generator} ]")
if isinstance(generator, ErrorResponse):
content = generator.model_dump_json()
context = json.loads(content)
context.append("status_code", generator.code)
await socket.send_multipart([identity, b"application/json", json.dumps(context).encode()])
elif isinstance(generator, CompletionResponse):
await socket.send_multipart([identity, b"application/json", JSONResponse.render(content=generator.model_dump())])
else:
async for chunk in generator:
logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
await socket.send_multipart([identity, b"text/event-stream", chunk.encode()])
except Exception as e:
logger.error(f"Error in worker routine: {e} worker-{i}")
logger.error(traceback.format_exc())
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
handler = completion(app)
logger.info(f"zmq requset post: {request}")
if handler is None:
return base(app).create_error_response(
message="The model does not support Completions API")
generator = await handler.create_completion(request, raw_request)
logger.info(f"zmq requset end post: {generator}")
return generator
def create_request(path: str, method: str, body: bytes, headers: dict = None):
scope = {
'type': 'http',
'http_version': '1.1',
'method': method,
'path': path,
'headers': list(headers.items()) if headers else [],
}
if body:
scope['body'] = json.dumps(body).encode('utf-8')
async def receive():
return {
'type': 'http.request',
'body': scope.get('body', b''),
}
async def send(message):
pass
return Request(scope, receive=receive, send=send)
if __name__ == "__main__":
print(bytes_to_headers(b'{"Content-Type": "application/json"}'))