mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-07 14:47:07 +08:00
add vllm connect cmd
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
2a0cb78016
commit
5d20f389d6
122
vllm/entrypoints/connect.py
Normal file
122
vllm/entrypoints/connect.py
Normal file
@ -0,0 +1,122 @@
|
||||
import json
|
||||
import uvicorn
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from contextlib import asynccontextmanager
|
||||
# from fastapi.lifespan import Lifespan
|
||||
from asyncio import Queue
|
||||
import uuid
|
||||
import signal
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# default prefill and decode url
|
||||
url_prefill = "tcp://localhost:8110"
|
||||
socket_prefill_num = 5
|
||||
url_decode = "tcp://localhost:8220"
|
||||
socket_decode_num = 5
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.connect')
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# create scoket pool with prefill and decode
|
||||
logger.info("start create_socket_pool")
|
||||
app.state.zmqctx = zmq.asyncio.Context()
|
||||
app.state.sockets_prefill = await create_socket_pool(app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
|
||||
logger.info("success create_socket_pool sockets_prefill")
|
||||
app.state.sockets_decode = await create_socket_pool(app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
|
||||
logger.info("success create_socket_pool sockets_decode")
|
||||
yield
|
||||
## close zmq context
|
||||
logger.info("term zmqctx")
|
||||
await app.state.zmqctx.term()
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
# create async socket pool with num_sockets use ZMQ_DEALER
|
||||
async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context):
|
||||
sockets = Queue()
|
||||
for i in range(num_sockets):
|
||||
sock = zmqctx.socket(zmq.DEALER)
|
||||
identity = f"worker-{i}-{uuid.uuid4()}"
|
||||
sock.setsockopt(zmq.IDENTITY, identity.encode())
|
||||
sock.connect(url)
|
||||
logger.info(f"{identity} started at {url} {sockets.qsize()}")
|
||||
await sockets.put(sock)
|
||||
return sockets
|
||||
|
||||
# select a scoket and execute task
|
||||
async def execute_task_async(request: dict, sockets: list):
|
||||
sock = await sockets.get()
|
||||
try:
|
||||
requestBody = json.dumps(request)
|
||||
logger.info(f"Sending requestBody: {requestBody}")
|
||||
await sock.send(requestBody.encode())
|
||||
logger.info(f"Sent end")
|
||||
while True:
|
||||
logger.info(f"Waiting for reply")
|
||||
reply = await sock.recv_multipart()
|
||||
logger.info(f"Received result: {reply}")
|
||||
yield f"data: {reply[0].decode()}\n\n"
|
||||
if "finish_reason" in reply[0].decode() and "stop" in reply[0].decode():
|
||||
logger.info(f"Received stop signal, return socket")
|
||||
yield "data: [DONE]\n\n"
|
||||
break
|
||||
finally:
|
||||
await sockets.put(sock)
|
||||
|
||||
@app.post('/v1/connect/completions')
|
||||
async def chat_completions(request: Request):
|
||||
try:
|
||||
original_request_data = await request.json()
|
||||
logger.info(f"Received request: {original_request_data}")
|
||||
prefill_request = original_request_data.copy()
|
||||
# change max_tokens = 1 to let it only do prefill
|
||||
prefill_request['max_tokens'] = 1
|
||||
|
||||
# finish prefill
|
||||
async for x in execute_task_async(prefill_request, app.state.sockets_prefill):
|
||||
logger.info(f"{x}")
|
||||
continue
|
||||
|
||||
# return decode
|
||||
return StreamingResponse(execute_task_async(original_request_data, app.state.sockets_decode), media_type="text/event-stream")
|
||||
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
exc_info = sys.exc_info()
|
||||
logger.error("Error occurred in disagg prefill proxy server")
|
||||
logger.error(e)
|
||||
logger.error("".join(traceback.format_exception(*exc_info)))
|
||||
|
||||
|
||||
async def run_connect(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM Connect start %s", args)
|
||||
logger.info(f"start connect {args} {uvicorn_kwargs}")
|
||||
logger.info(args.prefill_addr)
|
||||
|
||||
app.state.prefill_addr = f"tcp://{args.prefill_addr}" if args.prefill_addr is not None else url_prefill
|
||||
app.state.decode_addr = f"tcp://{args.decode_addr}" if args.decode_addr is not None else url_decode
|
||||
logger.info(f"start connect url_prefill: {app.state.prefill_addr} url_decode: {app.state.decode_addr}")
|
||||
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
# init uvicorn server
|
||||
config = uvicorn.Config(app, host="0.0.0.0", port=8001)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# url = 'tcp://127.0.0.1:5555'
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
Loading…
x
Reference in New Issue
Block a user