add /v1/completions stream support

Signed-off-by: clark <panf2333@gmail.com>

parent 905424ed65
commit bfde1688e7
@@ -50,22 +50,22 @@ async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context):
     return sockets
 
 # select a socket and execute task
-async def execute_task_async(route: str, headers: Headers, request: dict, sockets: list):
+async def execute_task_async(route: str, headers: dict, request: dict, sockets: list):
     sock = await sockets.get()
     try:
         requestBody = json.dumps(request)
-        headersJson = json.dumps(dict(headers))
+        headersJson = json.dumps(headers)
         logger.info(f"Sending requestBody: {requestBody} to {route} with headers: {headersJson}")
         await sock.send_multipart([route.encode(), headersJson.encode(), requestBody.encode()])
         logger.info("Sent end")
         while True:
             logger.info("Waiting for reply")
-            reply = await sock.recv_multipart()
-            logger.info(f"Received result: {reply}")
-            yield f"data: {reply[0].decode()}\n\n"
-            if "finish_reason" in reply[0].decode() and "stop" in reply[0].decode():
+            [contentType, reply] = await sock.recv_multipart()
+            logger.info(f"Received result: {contentType}, {reply}")
+            reply = reply.decode()
+            yield reply
+            if "[DONE]" in reply:
                 logger.info("Received stop signal, return socket")
-                yield "data: [DONE]\n\n"
                 break
     finally:
         await sockets.put(sock)
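Note: with this change the worker sends fully formed SSE lines, so the proxy forwards chunks verbatim and only watches for the [DONE] sentinel. A minimal client sketch for the endpoint this feeds; the URL, port, and model name below are illustrative assumptions, not part of this commit:

# sketch: consume the proxied completion stream; URL and model are assumptions
import asyncio
import httpx

async def stream_completion() -> None:
    payload = {"model": "my-model", "prompt": "Hello", "stream": True, "max_tokens": 16}
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", "http://localhost:8000/v1/connect/completions",
                                 json=payload) as resp:
            async for line in resp.aiter_lines():
                if line.startswith("data: "):
                    data = line[len("data: "):]
                    if data == "[DONE]":  # sentinel forwarded by execute_task_async
                        break
                    print(data)

asyncio.run(stream_completion())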
@@ -73,16 +73,20 @@ async def execute_task_async(route: str, headers: Headers, request: dict, sockets: list):
 @app.post('/v1/connect/completions')
 async def chat_completions(request: Request):
     try:
+        # Add the X-Request-Id header to the raw headers list
+        x_request_id = str(uuid.uuid4())
+        header = dict(request.headers)
+        if request.headers.get("X-Request-Id") is None:  # Headers lookup is case-insensitive; dict keys are lowercased
+            logger.info(f"add X-Request-Id: {x_request_id}")
+            header["X-Request-Id"] = x_request_id
         original_request_data = await request.json()
-        header = request.headers
-        logger.info(f"Received request: {original_request_data}")
+        logger.info(f"Received request: {original_request_data} header: {header}")
         prefill_request = original_request_data.copy()
         # change max_tokens = 1 to let it only do prefill
         prefill_request['max_tokens'] = 1
         route = "/v1/completions"
         # finish prefill
         async for x in execute_task_async(route, header, prefill_request, app.state.sockets_prefill):
             logger.info(f"{x}")
             continue
 
         # return decode
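The hunk is truncated at the decode step. Purely as an illustration of the intended flow (prefill with max_tokens=1, then stream the real decode), a plausible continuation is sketched below; StreamingResponse and the app.state.sockets_decode pool name are assumptions, not code from this commit:

# hypothetical continuation inside chat_completions; sockets_decode is an assumed name
# (requires: from fastapi.responses import StreamingResponse)
        # return decode: stream tokens from the decode pool back to the caller
        return StreamingResponse(
            execute_task_async(route, header, original_request_data,
                               app.state.sockets_decode),
            media_type="text/event-stream")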
vllm/entrypoints/launcher.py

@@ -6,9 +6,13 @@ import socket
 from http import HTTPStatus
 from typing import Any, Optional
 
+import zmq
+import zmq.asyncio
+
 import uvicorn
 from fastapi import FastAPI, Request, Response
 
+from vllm.entrypoints.openai.connect_worker import worker_routine
 from vllm import envs
 from vllm.engine.async_llm_engine import AsyncEngineDeadError
 from vllm.engine.multiprocessing import MQEngineDeadError
@@ -72,8 +76,39 @@ async def serve_http(app: FastAPI,
                 "port %s is used by process %s launched with command:\n%s",
                 port, process, " ".join(process.cmdline()))
     logger.info("Shutting down FastAPI HTTP server.")
-    return server.shutdown()
+    return server.shutdown()
 
+async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
+    """Server routine"""
+    logger.info(f"zmq Server start arg: {arg}, zmq_port: {zmq_server_port}")
+    url_worker = "inproc://workers"
+    url_client = f"tcp://0.0.0.0:{zmq_server_port}"
+    # Prepare our context and sockets
+    context = zmq.asyncio.Context()
+
+    # Socket to talk to clients
+    clients = context.socket(zmq.ROUTER)
+    clients.bind(url_client)
+    logger.info(f"ZMQ Server ROUTER started at {url_client}")
+    # Socket to talk to workers
+    workers = context.socket(zmq.DEALER)
+    workers.bind(url_worker)
+    logger.info(f"ZMQ Worker DEALER started at {url_worker}")
+
+    tasks = [asyncio.create_task(worker_routine(url_worker, app, context, i)) for i in range(5)]
+    proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
+
+    try:
+        await asyncio.gather(*tasks, proxy_task)
+    except KeyboardInterrupt:
+        print("ZMQ Server interrupted")
+    except zmq.ZMQError as e:
+        print("ZMQError:", e)
+    finally:
+        # We never get here but clean up anyhow
+        clients.close()
+        workers.close()
+        context.term()
 
 def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
     """Adds handlers for fatal errors that should crash the server"""
 
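serve_zmq follows the classic ZeroMQ shared-queue broker pattern: a ROUTER socket faces remote clients, a DEALER socket faces in-process workers, and zmq.proxy shuttles frames between them. A self-contained sketch of the same pattern, assuming illustrative values for the port, worker count, and the REP echo workers:

# minimal ROUTER/DEALER broker sketch; REQ clients would connect to port 5570
import threading
import zmq

def worker(ctx: zmq.Context) -> None:
    sock = ctx.socket(zmq.REP)
    sock.connect("inproc://workers")  # attach to the DEALER backend
    while True:
        msg = sock.recv()
        sock.send(b"echo: " + msg)    # reply routes back through the proxy

ctx = zmq.Context()
clients = ctx.socket(zmq.ROUTER)      # client-facing frontend
clients.bind("tcp://127.0.0.1:5570")
workers = ctx.socket(zmq.DEALER)      # worker-facing backend
workers.bind("inproc://workers")
for _ in range(2):
    threading.Thread(target=worker, args=(ctx,), daemon=True).start()
zmq.proxy(clients, workers)           # blocks; shuttles frames both ways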
vllm/entrypoints/openai/api_server.py

@@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.engine.multiprocessing.engine import run_mp_engine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import load_chat_template
-from vllm.entrypoints.launcher import serve_http
+from vllm.entrypoints.launcher import serve_http, serve_zmq
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
@@ -1029,6 +1029,11 @@ async def run_server(args, **uvicorn_kwargs) -> None:
                 "s" if is_ssl else "", _listen_addr(sock_addr[0]),
                 sock_addr[1])
 
+    zmq_server_port = args.zmq_server_port
+    if zmq_server_port is not None:
+        logger.info("Starting ZMQ server at port %d via asyncio.create_task", zmq_server_port)
+        asyncio.create_task(serve_zmq(args, zmq_server_port, app))
+
     shutdown_task = await serve_http(
         app,
         sock=sock,
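Once the ZMQ server is up, a prefill/decode client speaks the frame protocol directly: it sends [route, headers-JSON, body-JSON] on a DEALER socket and reads [content-type, payload] pairs back until it sees [DONE]. A sketch, assuming the server was launched with a ZMQ port of 5570; the port and request body are illustrative:

# sketch: raw DEALER client for the ZMQ front door; port and body are assumptions
import asyncio
import json
import zmq
import zmq.asyncio

async def main() -> None:
    ctx = zmq.asyncio.Context()
    sock = ctx.socket(zmq.DEALER)
    sock.connect("tcp://127.0.0.1:5570")
    headers = {"X-Request-Id": "demo"}
    body = {"model": "my-model", "prompt": "Hello", "stream": True, "max_tokens": 8}
    # frame layout matches execute_task_async: route, headers JSON, body JSON
    await sock.send_multipart([b"/v1/completions",
                               json.dumps(headers).encode(),
                               json.dumps(body).encode()])
    while True:
        content_type, payload = await sock.recv_multipart()
        print(content_type.decode(), payload.decode())
        if b"[DONE]" in payload:
            break
    sock.close()
    ctx.term()

asyncio.run(main())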
vllm/entrypoints/openai/connect_worker.py (new file, 129 lines)

@@ -0,0 +1,129 @@
+import json
+import tempfile
+import traceback
+import uuid
+from typing import Optional
+
+import httpx
+import zmq
+import zmq.asyncio
+from fastapi import FastAPI, Request
+
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (CompletionRequest,
+                                              CompletionResponse,
+                                              ErrorResponse)
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+from vllm.logger import init_logger
+
+prometheus_multiproc_dir: tempfile.TemporaryDirectory
+
+# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
+logger = init_logger('vllm.entrypoints.openai.connect_worker')
+
+
+def base(app: FastAPI) -> OpenAIServing:
+    # Reuse the existing instance
+    return tokenization(app)
+
+
+def models(app: FastAPI) -> OpenAIServingModels:
+    return app.state.openai_serving_models
+
+
+def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
+    return app.state.openai_serving_chat
+
+
+def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
+    return app.state.openai_serving_completion
+
+
+def tokenization(app: FastAPI) -> OpenAIServingTokenization:
+    return app.state.openai_serving_tokenization
+
+
+def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
+    headers_dict = json.loads(bytes_data.decode())
+    return httpx.Headers(headers_dict)
+
+
+async def worker_routine(worker_url: str, app: FastAPI,
+                         context: zmq.asyncio.Context = None, i: int = 0):
+    """Worker routine"""
+    try:
+        # Socket to talk to dispatcher
+        socket = context.socket(zmq.DEALER)
+        worker_identity = f"worker-{i}-{uuid.uuid4()}"
+        socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
+        socket.connect(worker_url)
+        logger.info(f"{worker_identity} started at {worker_url}")
+        while True:
+            # ROUTER prepends the client identity to the three client frames
+            identity, url, header, body = await socket.recv_multipart()
+            logger.info(f"worker-{i} received request identity: [{identity}]")
+            url = url.decode()
+            logger.info(f"worker-{i} received request url: [{url}]")
+            header = bytes_to_headers(header)
+            logger.info(f"worker-{i} received request headers: [{header}]")
+            body = json.loads(body.decode())
+            logger.info(f"worker-{i} received request body: [{body}]")
+            logger.info(f"worker-{i} calling OpenAI API")
+            completionRequest = CompletionRequest(**body)
+            createRequest = create_request(url, "POST", body, header)
+            generator = await create_completion(app, completionRequest, createRequest)
+            logger.info(f"worker-{i} received response: [{generator}]")
+            if isinstance(generator, ErrorResponse):
+                # dicts have no append(); set the status code as a key instead
+                content = json.loads(generator.model_dump_json())
+                content["status_code"] = generator.code
+                await socket.send_multipart([identity, b"application/json", json.dumps(content).encode()])
+            elif isinstance(generator, CompletionResponse):
+                # JSONResponse.render() is an instance method; serialize directly
+                await socket.send_multipart([identity, b"application/json", generator.model_dump_json().encode()])
+            else:
+                async for chunk in generator:
+                    logger.info(f"worker-{i} sending response chunk: [{chunk}]")
+                    await socket.send_multipart([identity, b"text/event-stream", chunk.encode()])
+    except Exception as e:
+        logger.error(f"Error in worker routine: {e} worker-{i}")
+        logger.error(traceback.format_exc())
+
+
+async def create_completion(app: FastAPI, request: CompletionRequest, raw_request: Request):
+    handler = completion(app)
+    logger.info(f"zmq request post: {request}")
+    if handler is None:
+        return base(app).create_error_response(
+            message="The model does not support Completions API")
+
+    generator = await handler.create_completion(request, raw_request)
+    logger.info(f"zmq request end post: {generator}")
+    return generator
+
+
+def create_request(path: str, method: str, body: dict = None, headers: httpx.Headers = None):
+    scope = {
+        'type': 'http',
+        'http_version': '1.1',
+        'method': method,
+        'path': path,
+        # ASGI expects header names and values as bytes
+        'headers': [(k.encode(), v.encode()) for k, v in headers.items()] if headers else [],
+    }
+    if body:
+        scope['body'] = json.dumps(body).encode('utf-8')
+
+    async def receive():
+        return {
+            'type': 'http.request',
+            'body': scope.get('body', b''),
+        }
+
+    async def send(message):
+        pass
+
+    return Request(scope, receive=receive, send=send)
+
+
+if __name__ == "__main__":
+    print(bytes_to_headers(b'{"Content-Type": "application/json"}'))
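The synthetic ASGI Request built by create_request is what lets the worker call the regular OpenAIServingCompletion handler outside any HTTP server. A quick check that it behaves like a real request (module path per this commit; the header and body values are illustrative):

# sketch: exercise create_request outside an HTTP server; values are illustrative
import asyncio
import httpx
from vllm.entrypoints.openai.connect_worker import create_request

async def demo() -> None:
    req = create_request("/v1/completions", "POST",
                         {"prompt": "hi"},
                         httpx.Headers({"x-request-id": "demo"}))
    print(req.method, req.url.path)        # POST /v1/completions
    print(req.headers["x-request-id"])     # demo
    print(await req.json())                # {'prompt': 'hi'}

asyncio.run(demo())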