Add streaming support for /v1/completions

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-01-07 22:33:35 +08:00
parent 905424ed65
commit bfde1688e7
4 changed files with 185 additions and 12 deletions

View File

@ -50,22 +50,22 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Con
return sockets return sockets
# select a socket and execute task # select a socket and execute task
async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
    """Send one request over a pooled socket and yield each reply payload.

    A socket is borrowed from ``sockets``, the request is sent as three
    frames (route, JSON-encoded headers, JSON-encoded body) and every
    two-frame reply ``[content_type, payload]`` is decoded and yielded as
    text. The socket is always handed back to the pool when the generator
    finishes or is closed.

    Args:
        route: Path of the remote endpoint, sent as the first frame.
        headers: JSON-serialisable mapping of request headers.
        request: JSON-serialisable request body.
        sockets: Pool of sockets. Despite the ``list`` annotation it must
            expose awaitable ``get()`` and ``put()`` methods, because that
            is how it is used here.

    Yields:
        The decoded payload of each reply frame, unmodified.
    """
    sock = await sockets.get()
    try:
        request_body = json.dumps(request)
        headers_json = json.dumps(headers)
        logger.info(f"Sending requestBody: {request_body} to {route} with headers: {headers_json}")
        await sock.send_multipart([route.encode(), headers_json.encode(), request_body.encode()])
        logger.info("Sent end")
        while True:
            logger.info("Waiting for reply")
            content_type, payload = await sock.recv_multipart()
            logger.info(f"Received result: {content_type}, {payload}")
            text = payload.decode()
            yield text
            # Errors and non-streaming completions arrive as a single
            # application/json frame that never carries "[DONE]"; without
            # this check the loop would wait forever and the socket would
            # never go back to the pool.
            if content_type == b"application/json":
                logger.info("Received complete JSON reply, return socket")
                break
            if "[DONE]" in text:
                logger.info("Received stop signal, return socket")
                break
    finally:
        await sockets.put(sock)
@ -73,16 +73,20 @@ async def execute_task_async(route: str, headers: Headers, request: dict, socket
@app.post('/v1/connect/completions') @app.post('/v1/connect/completions')
async def chat_completions(request: Request): async def chat_completions(request: Request):
try: try:
# Add the X-Request-Id header to the raw headers list
x_request_id = str(uuid.uuid4())
header = dict(request.headers)
if header.get("X-Request-Id") is None:
logger.info(f"add X-Request-Id: {x_request_id}")
header["X-Request-Id"] = x_request_id
original_request_data = await request.json() original_request_data = await request.json()
header = request.headers logger.info(f"Received request: {original_request_data} header: {header}")
logger.info(f"Received request: {original_request_data}")
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
prefill_request['max_tokens'] = 1 prefill_request['max_tokens'] = 1
route = "/v1/completions" route = "/v1/completions"
# finish prefill # finish prefill
async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill): async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
logger.info(f"{x}")
continue continue
# return decode # return decode

View File

@ -6,9 +6,13 @@ import socket
from http import HTTPStatus from http import HTTPStatus
from typing import Any, Optional from typing import Any, Optional
import zmq
import zmq.asyncio
import uvicorn import uvicorn
from fastapi import FastAPI, Request, Response from fastapi import FastAPI, Request, Response
from vllm.entrypoints.openai.connect_worker import worker_routine
from vllm import envs from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError from vllm.engine.multiprocessing import MQEngineDeadError
@ -74,6 +78,37 @@ async def serve_http(app: FastAPI,
logger.info("Shutting down FastAPI HTTP server.") logger.info("Shutting down FastAPI HTTP server.")
return server.shutdown() return server.shutdown()
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI,
                    num_workers: int = 5) -> None:
    """Run the ZMQ front end: a ROUTER/DEALER proxy plus worker tasks.

    Clients connect to a TCP ROUTER socket; requests are forwarded through
    ``zmq.proxy`` to an in-process DEALER socket that the worker tasks
    connect to.

    Args:
        arg: Parsed server arguments; only logged here.
        zmq_server_port: TCP port the ROUTER socket binds on all interfaces.
        app: FastAPI application handed to every worker.
        num_workers: Number of ``worker_routine`` tasks to start.
    """
    logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
    url_worker = "inproc://workers"
    url_client = f"tcp://0.0.0.0:{zmq_server_port}"
    # Prepare our context and sockets
    context = zmq.asyncio.Context()
    tasks = []
    try:
        # Socket to talk to clients
        clients = context.socket(zmq.ROUTER)
        clients.bind(url_client)
        logger.info(f"ZMQ Server ROUTER started at {url_client}")
        # Socket to talk to workers
        workers = context.socket(zmq.DEALER)
        workers.bind(url_worker)
        logger.info(f"ZMQ Worker DEALER started at {url_worker}")
        tasks = [
            asyncio.create_task(worker_routine(url_worker, app, context, i))
            for i in range(num_workers)
        ]
        # zmq.proxy blocks, so it runs in a helper thread.
        proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
        await asyncio.gather(*tasks, proxy_task)
    except KeyboardInterrupt:
        logger.info("ZMQ Server interrupted")
    except zmq.ZMQError as e:
        logger.error(f"ZMQError: {e}")
    finally:
        # Stop workers that are still running so they do not keep using a
        # context that is being torn down.
        for task in tasks:
            task.cancel()
        # destroy() closes every socket created from this context before
        # terminating it; a bare term() would block for as long as any
        # worker socket is still open.
        # NOTE(review): the proxy thread cannot be interrupted from here and
        # may still be inside zmq.proxy when its sockets are closed.
        context.destroy(linger=0)
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None: def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server""" """Adds handlers for fatal errors that should crash the server"""

View File

@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http from vllm.entrypoints.launcher import serve_http, serve_zmq
from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser, from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args) validate_parsed_serve_args)
@ -1029,6 +1029,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
"s" if is_ssl else "", _listen_addr(sock_addr[0]), "s" if is_ssl else "", _listen_addr(sock_addr[0]),
sock_addr[1]) sock_addr[1])
zmq_server_port = args.zmq_server_port
if zmq_server_port is not None:
logger.info("asyncio.create_task Starting ZMQ server at port %d", zmq_server_port)
asyncio.create_task(serve_zmq(args, zmq_server_port, app))
shutdown_task = await serve_http( shutdown_task = await serve_http(
app, app,
sock=sock, sock=sock,

View File

@ -0,0 +1,129 @@
import json
from typing import Optional
import zmq
import zmq.asyncio
import tempfile
import uuid
import httpx
import json
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionRequest,
CompletionResponse,
ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.logger import init_logger
import traceback
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.connect_worker')
def base(app: FastAPI) -> OpenAIServing:
    """Return a generic serving handler for ``app``.

    No dedicated instance is kept for this purpose; the tokenization
    handler is reused instead.
    """
    shared_handler = tokenization(app)
    return shared_handler
def models(app: FastAPI) -> OpenAIServingModels:
    """Return the models handler stored on ``app.state``."""
    state = app.state
    return state.openai_serving_models
def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
    """Return the chat handler stored on ``app.state`` (may be ``None``)."""
    state = app.state
    return state.openai_serving_chat
def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
    """Return the completion handler stored on ``app.state`` (may be ``None``)."""
    state = app.state
    return state.openai_serving_completion
def tokenization(app: FastAPI) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on ``app.state``."""
    state = app.state
    return state.openai_serving_tokenization
def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
    """Decode a JSON-encoded header mapping into an ``httpx.Headers`` object.

    Args:
        bytes_data: Encoded JSON document describing the headers.

    Returns:
        The headers wrapped in ``httpx.Headers``.
    """
    decoded_text = bytes_data.decode()
    return httpx.Headers(json.loads(decoded_text))
async def _send_json(sock, identity: bytes, payload: dict) -> None:
    """Send ``payload`` to client ``identity`` as one application/json reply."""
    await sock.send_multipart(
        [identity, b"application/json", json.dumps(payload).encode()])


async def _serve_one_request(sock, app: FastAPI, i: int, identity: bytes,
                             url: bytes, header: bytes, body: bytes) -> None:
    """Decode one request, run it through the completion handler and reply."""
    url = url.decode()
    logger.info(f"worker-{i} Received request url: [{url} ]")
    header = bytes_to_headers(header)
    logger.info(f"worker-{i} Received request headers: [{header} ]")
    body = json.loads(body.decode())
    logger.info(f"worker-{i} Received request body: [{body} ]")
    logger.info(f"worker-{i} Calling OpenAI API")
    completion_request = CompletionRequest(**body)
    raw_request = create_request(url, "POST", body, header)
    generator = await create_completion(app, completion_request, raw_request)
    logger.info(f"worker-{i} Received response: [{generator} ]")
    if isinstance(generator, ErrorResponse):
        # Use a separate name so the ZMQ context is not shadowed, and set
        # the key by item assignment (a dict has no append()).
        error_payload = json.loads(generator.model_dump_json())
        error_payload["status_code"] = generator.code
        await _send_json(sock, identity, error_payload)
    elif isinstance(generator, CompletionResponse):
        # JSONResponse.render is an instance method and cannot be called on
        # the class; serialise the model directly instead.
        await sock.send_multipart(
            [identity, b"application/json", generator.model_dump_json().encode()])
    else:
        async for chunk in generator:
            logger.info(f"worker-{i} Sending response chunk: [{chunk} ]")
            await sock.send_multipart([identity, b"text/event-stream", chunk.encode()])


async def worker_routine(worker_url: str, app: FastAPI,
                         context: zmq.asyncio.Context = None, i: int = 0):
    """Serve completion requests arriving on a DEALER socket, forever.

    Each request is four frames: client identity, URL, JSON headers and
    JSON body. Replies are ``[identity, content_type, payload]``: one
    ``application/json`` frame for errors and non-streaming completions, or
    a sequence of ``text/event-stream`` frames for streaming completions.

    A failure while handling one request is logged and reported to that
    client as a JSON error; the worker then keeps serving. The socket is
    closed when the routine exits.

    Args:
        worker_url: Endpoint the worker socket connects to.
        app: FastAPI application whose state holds the serving handlers.
        context: ZMQ context used to create the socket. Required in
            practice: the default of ``None`` fails on first use.
        i: Worker index, used in the socket identity and log messages.
    """
    sock = None
    try:
        # Socket to talk to dispatcher
        sock = context.socket(zmq.DEALER)
        worker_identity = f"worker-{i}-{uuid.uuid4()}"
        sock.setsockopt(zmq.IDENTITY, worker_identity.encode())
        sock.connect(worker_url)
        logger.info(f"{worker_identity} started at {worker_url}")
        while True:
            frames = await sock.recv_multipart()
            if len(frames) != 4:
                # Without the expected framing there is nothing to reply to.
                logger.error(
                    f"worker-{i} Dropping malformed request with {len(frames)} frames")
                continue
            identity, url, header, body = frames
            logger.info(f"worker-{i} Received request identity: [{identity} ]")
            try:
                await _serve_one_request(sock, app, i, identity, url, header, body)
            except Exception as e:
                # One bad request must not take the worker down or leave
                # the client waiting for a reply that never comes.
                logger.error(f"Error in worker routine: {e} worker-{i}")
                logger.error(traceback.format_exc())
                await _send_json(sock, identity, {
                    "object": "error",
                    "message": str(e),
                    "type": "InternalServerError",
                    "code": 500,
                    "status_code": 500,
                })
    except Exception as e:
        logger.error(f"Error in worker routine: {e} worker-{i}")
        logger.error(traceback.format_exc())
    finally:
        if sock is not None:
            sock.close(linger=0)
async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
    """Run a completion request through the app's completion handler.

    Args:
        app: Application whose state holds the serving handlers.
        request: Parsed completion request.
        raw_request: Request object passed through to the handler.

    Returns:
        Whatever the completion handler returns, or an error response from
        the base handler when no completion handler is configured.
    """
    serving = completion(app)
    logger.info(f"zmq requset post: {request}")
    if serving is not None:
        result = await serving.create_completion(request, raw_request)
        logger.info(f"zmq requset end post: {result}")
        return result
    return base(app).create_error_response(
        message="The model does not support Completions API")
def create_request(path: str, method: str, body: bytes, headers: dict = None):
    """Build a ``Request`` object from plain values, without an HTTP server.

    Args:
        path: Request path placed in the ASGI scope.
        method: HTTP method placed in the ASGI scope.
        body: Request body. Bytes are used as-is; any other non-empty value
            (for example an already parsed dict) is JSON-encoded.
        headers: Optional mapping of header names to values.

    Returns:
        A ``Request`` whose ``receive`` callable yields the body once.
    """

    def _as_bytes(value) -> bytes:
        if isinstance(value, bytes):
            return value
        return str(value).encode("latin-1")

    # ASGI requires headers as pairs of byte strings with lower-cased
    # names; str pairs would never match a header lookup.
    raw_headers = [
        (_as_bytes(name).lower(), _as_bytes(value))
        for name, value in headers.items()
    ] if headers else []

    if not body:
        payload = b''
    elif isinstance(body, (bytes, bytearray)):
        payload = bytes(body)
    else:
        payload = json.dumps(body).encode('utf-8')

    scope = {
        'type': 'http',
        'http_version': '1.1',
        'method': method,
        'path': path,
        'query_string': b'',
        'headers': raw_headers,
    }
    if payload:
        # Kept in the scope as well, as the original implementation did.
        scope['body'] = payload

    async def receive():
        return {
            'type': 'http.request',
            'body': payload,
            'more_body': False,
        }

    async def send(message):
        # Responses go back over ZMQ, so anything sent here is dropped.
        pass

    return Request(scope, receive=receive, send=send)
if __name__ == "__main__":
    # Manual smoke check: parse a sample header blob and show the result.
    sample_headers = b'{"Content-Type": "application/json"}'
    print(bytes_to_headers(sample_headers))