fix ThreadProxy

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-01-12 21:12:53 +08:00
parent ee6607332e
commit 27c1afe88b
3 changed files with 61 additions and 37 deletions

View File

@ -41,7 +41,7 @@ async def test_connect_completions(session):
async for chunk in response.content.iter_chunked(1024):
try:
decoded_chunk = chunk.decode('utf-8')
print(f"Decoded chunk: {decoded_chunk!r}")
# print(f"Decoded chunk: {decoded_chunk!r}")
responseText += decoded_chunk
except UnicodeDecodeError:
print(f"Error decoding chunk: {chunk!r}")
@ -55,6 +55,7 @@ async def test_connect_completions(session):
response.json()))
else:
print(f"Request failed with status code {response.status}")
print(f"Response : {await response.json()}")
print(f"baseurl {base_url}")
print(f"response data {extract_data(responseText)}")
except aiohttp.ClientError as e:
@ -98,7 +99,7 @@ def extract_data(responseText):
async def main():
async with aiohttp.ClientSession() as session:
tasks = []
for _ in range(1):
for _ in range(2):
tasks.append(test_connect_completions(session))
await asyncio.gather(*tasks)

View File

@ -1,7 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import signal
import traceback
import uuid
# from fastapi.lifespan import Lifespan
from asyncio import Queue
@ -17,12 +19,14 @@ from fastapi.responses import JSONResponse, StreamingResponse
from vllm.logger import init_logger
# default prefill and decode addr
time_out = 3
fastapi_port = 8001
prefill_addr = "ipc://localhost:7010"
socket_prefill_num = 5
socket_prefill_num = 20
decode_addr = "ipc://localhost:7020"
socket_decode_num = 5
socket_decode_num = 20
context_type_json = "application/json"
context_type_error = "error"
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.connect')
@ -51,7 +55,7 @@ app = FastAPI(lifespan=lifespan)
# create async socket pool with num_sockets use ZMQ_DEALER
async def create_socket_pool(url: str, num_sockets: int,
zmqctx: zmq.asyncio.Context) -> Queue:
sockets: Queue = Queue()
sockets: Queue[zmq.Socket] = Queue()
for i in range(num_sockets):
sock = zmqctx.socket(zmq.DEALER)
identity = f"worker-{i}-{uuid.uuid4()}"
@ -66,20 +70,23 @@ async def create_socket_pool(url: str, num_sockets: int,
# select a socket and execute task
async def execute_task_async(route: str, headers: dict, request: dict,
sockets: Queue):
sock = await sockets.get()
sock: zmq.Socket = await sockets.get()
try:
requestBody = json.dumps(request)
headersJson = json.dumps(headers)
logger.info("Sending requestBody: %s to %s with headers: %s",
requestBody, route, headersJson)
await sock.send_multipart(
await asyncio.wait_for(sock.send_multipart(
[route.encode(),
headersJson.encode(),
requestBody.encode()])
requestBody.encode()]),
timeout=time_out)
logger.info("Sent end")
while True:
logger.info("Waiting for reply")
[contentType, reply] = await sock.recv_multipart()
[contentType,
reply] = await asyncio.wait_for(sock.recv_multipart(),
timeout=time_out)
contentType_str = contentType.decode()
reply_str = reply.decode()
logger.info("Received result: %s, %s", contentType_str, reply_str)
@ -91,6 +98,11 @@ async def execute_task_async(route: str, headers: dict, request: dict,
if "[DONE]" in reply_str:
logger.info("Received stop signal, return socket")
break
except asyncio.TimeoutError:
logger.error(traceback.format_exc())
logger.error("Timeout, return socket: %s",
sock.getsockopt(zmq.IDENTITY))
yield (context_type_error, "System Error")
finally:
await sockets.put(sock)
@ -101,16 +113,30 @@ async def generate_stream_response(fisrt_reply: str,
async for _, reply in generator:
yield reply
async def prefill(route: str, header: dict, original_request_data: dict):
logger.info("start prefill")
generator = execute_task_async(route, header, original_request_data,
app.state.sockets_prefill)
async for contentType, reply in generator:
logger.info("contentType: %s, reply: %s", contentType, reply)
if context_type_error == contentType:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
return True
async def decode(route: str, header: dict, original_request_data: dict):
logger.info("start decode")
generator = execute_task_async(route, header, original_request_data,
app.state.sockets_decode)
logger.info("finish decode")
async for contentType, reply in generator:
logger.info("contentType: %s, reply: %s", contentType, reply)
if context_type_json == contentType:
if context_type_error == contentType:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
elif context_type_json == contentType:
return JSONResponse(reply)
else:
return StreamingResponse(generate_stream_response(
@ -135,12 +161,17 @@ async def chat_completions(request: Request):
prefill_request['max_tokens'] = 1
route = "/v1/completions"
# finish prefill
async for _ in execute_task_async(route, header, prefill_request,
app.state.sockets_prefill):
continue
logger.info("finish prefill start decode")
response = await decode(route, header, original_request_data)
try:
prefill_response = await prefill(route, header, prefill_request)
if isinstance(prefill_response, JSONResponse):
return prefill_response
logger.info("finish prefill start decode")
response = await decode(route, header, original_request_data)
logger.info("finish decode")
except Exception as e:
logger.error("Error occurred in disagg prefill proxy server, %s",
e)
response = JSONResponse({"error": {"message": str(e)}})
return response
except Exception as e:

View File

@ -9,6 +9,7 @@ from typing import Any, Optional
import uvicorn
import zmq
import zmq.asyncio
import zmq.devices
from fastapi import FastAPI, Request, Response
from vllm import envs
@ -78,37 +79,28 @@ async def serve_http(app: FastAPI,
return server.shutdown()
def proxy(clients_addr: str, workers_addr: str,
ctx: zmq.asyncio.Context) -> None:
in_socket = ctx.socket(zmq.ROUTER)
in_socket.bind(clients_addr)
out_socket = ctx.socket(zmq.DEALER)
out_socket.bind(workers_addr)
try:
zmq.proxy(in_socket, out_socket)
except zmq.ContextTerminated:
print("proxy terminated")
in_socket.close()
out_socket.close()
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
"""Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
zmq_server_port)
workers_addr = "inproc://workers"
# different zmq context can't communicate use inproc
workers_addr = "ipc://workers"
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
# Prepare our context and sockets
context = zmq.asyncio.Context.instance()
try:
tasks = [
asyncio.create_task(worker_routine(workers_addr, app, context, i))
for i in range(5)
for i in range(20)
]
logger.info("zmq tasks: %s", tasks)
proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr,
context)
await asyncio.gather(*tasks, proxy_task)
# thread safety proxy create socket in the background:
# https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices
thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER)
thread_proxy.bind_in(clients_addr)
thread_proxy.bind_out(workers_addr)
thread_proxy.start()
await asyncio.gather(*tasks)
except KeyboardInterrupt:
print("ZMQ Server interrupted")
except zmq.ZMQError as e: