mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-10 11:41:20 +08:00
create proxy sockets in the proxy function for thread safety
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
7fbf70db57
commit
ee6607332e
@ -7,7 +7,7 @@ import aiohttp
|
|||||||
# test connect completions we assume prefill and decode are on the same node
|
# test connect completions we assume prefill and decode are on the same node
|
||||||
# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
|
# 1. node:vllm serve facebook/opt-125m --port 7001 --zmq-server-port 7010 \
|
||||||
# --chat-template ~/vllm/examples/template_chatglm2.jinja
|
# --chat-template ~/vllm/examples/template_chatglm2.jinja
|
||||||
# 2. vllm connect --prefill-addr nodeIp:7010 --decode-addr nodeIp:7010
|
# 2. vllm connect --prefill-addr 127.0.0.1:7010 --decode-addr 127.0.0.1:7010
|
||||||
# 3. python test_request.py
|
# 3. python test_request.py
|
||||||
async def test_connect_completions(session):
|
async def test_connect_completions(session):
|
||||||
try:
|
try:
|
||||||
@ -68,11 +68,12 @@ def is_json(data):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def extract_data(responseText):
|
def extract_data(responseText):
|
||||||
|
reply = ""
|
||||||
if responseText == "":
|
if responseText == "":
|
||||||
return ""
|
return reply
|
||||||
if is_json(responseText):
|
if is_json(responseText):
|
||||||
return responseText
|
return responseText
|
||||||
reply = ""
|
|
||||||
for data in responseText.split("\n\n"):
|
for data in responseText.split("\n\n"):
|
||||||
if data.startswith('data: '):
|
if data.startswith('data: '):
|
||||||
content = data[6:]
|
content = data[6:]
|
||||||
|
|||||||
@ -78,6 +78,20 @@ async def serve_http(app: FastAPI,
|
|||||||
return server.shutdown()
|
return server.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
def proxy(clients_addr: str, workers_addr: str,
|
||||||
|
ctx: zmq.asyncio.Context) -> None:
|
||||||
|
in_socket = ctx.socket(zmq.ROUTER)
|
||||||
|
in_socket.bind(clients_addr)
|
||||||
|
out_socket = ctx.socket(zmq.DEALER)
|
||||||
|
out_socket.bind(workers_addr)
|
||||||
|
try:
|
||||||
|
zmq.proxy(in_socket, out_socket)
|
||||||
|
except zmq.ContextTerminated:
|
||||||
|
print("proxy terminated")
|
||||||
|
in_socket.close()
|
||||||
|
out_socket.close()
|
||||||
|
|
||||||
|
|
||||||
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
||||||
"""Server routine"""
|
"""Server routine"""
|
||||||
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
||||||
@ -85,24 +99,15 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
|||||||
workers_addr = "inproc://workers"
|
workers_addr = "inproc://workers"
|
||||||
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
|
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
|
||||||
# Prepare our context and sockets
|
# Prepare our context and sockets
|
||||||
context = zmq.asyncio.Context()
|
context = zmq.asyncio.Context.instance()
|
||||||
|
|
||||||
# Socket to talk to clients
|
|
||||||
clients = context.socket(zmq.ROUTER)
|
|
||||||
clients.bind(clients_addr)
|
|
||||||
logger.info("ZMQ Server ROUTER started at %s", clients_addr)
|
|
||||||
# Socket to talk to workers
|
|
||||||
workers = context.socket(zmq.DEALER)
|
|
||||||
workers.bind(workers_addr)
|
|
||||||
logger.info("ZMQ Worker DEALER started at %s", workers_addr)
|
|
||||||
|
|
||||||
tasks = [
|
|
||||||
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
|
||||||
for i in range(5)
|
|
||||||
]
|
|
||||||
proxy_task = asyncio.to_thread(zmq.proxy, clients, workers)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
tasks = [
|
||||||
|
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
||||||
|
for i in range(5)
|
||||||
|
]
|
||||||
|
logger.info("zmq tasks: %s", tasks)
|
||||||
|
proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr,
|
||||||
|
context)
|
||||||
await asyncio.gather(*tasks, proxy_task)
|
await asyncio.gather(*tasks, proxy_task)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("ZMQ Server interrupted")
|
print("ZMQ Server interrupted")
|
||||||
@ -110,8 +115,6 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
|||||||
print("ZMQError:", e)
|
print("ZMQError:", e)
|
||||||
finally:
|
finally:
|
||||||
# We never get here but clean up anyhow
|
# We never get here but clean up anyhow
|
||||||
clients.close()
|
|
||||||
workers.close()
|
|
||||||
context.destroy(linger=0)
|
context.destroy(linger=0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user