From 27c1afe88b4ab8ac6b42c256053dd6da7dc79625 Mon Sep 17 00:00:00 2001 From: clark Date: Sun, 12 Jan 2025 21:12:53 +0800 Subject: [PATCH] fix ThreadProxy Signed-off-by: clark --- .../disagg_benchmarks/zmq/test_request.py | 5 +- vllm/entrypoints/disagg_connector.py | 63 ++++++++++++++----- vllm/entrypoints/launcher.py | 30 ++++----- 3 files changed, 61 insertions(+), 37 deletions(-) diff --git a/benchmarks/disagg_benchmarks/zmq/test_request.py b/benchmarks/disagg_benchmarks/zmq/test_request.py index 5aa66ebaf7beb..b881aca790526 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_request.py +++ b/benchmarks/disagg_benchmarks/zmq/test_request.py @@ -41,7 +41,7 @@ async def test_connect_completions(session): async for chunk in response.content.iter_chunked(1024): try: decoded_chunk = chunk.decode('utf-8') - print(f"Decoded chunk: {decoded_chunk!r}") + # print(f"Decoded chunk: {decoded_chunk!r}") responseText += decoded_chunk except UnicodeDecodeError: print(f"Error decoding chunk: {chunk!r}") @@ -55,6 +55,7 @@ async def test_connect_completions(session): response.json())) else: print(f"Request failed with status code {response.status}") + print(f"Response : {await response.json()}") print(f"baseurl {base_url}") print(f"response data {extract_data(responseText)}") except aiohttp.ClientError as e: @@ -98,7 +99,7 @@ def extract_data(responseText): async def main(): async with aiohttp.ClientSession() as session: tasks = [] - for _ in range(1): + for _ in range(2): tasks.append(test_connect_completions(session)) await asyncio.gather(*tasks) diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index 71beeaf1e6083..a79b37658268e 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 +import asyncio import json import signal +import traceback import uuid # from fastapi.lifespan import Lifespan from asyncio import Queue @@ -17,12 +19,14 @@ from fastapi.responses import JSONResponse, StreamingResponse from vllm.logger import init_logger # default prefill and decode addr +time_out = 3 fastapi_port = 8001 prefill_addr = "ipc://localhost:7010" -socket_prefill_num = 5 +socket_prefill_num = 20 decode_addr = "ipc://localhost:7020" -socket_decode_num = 5 +socket_decode_num = 20 context_type_json = "application/json" +context_type_error = "error" # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.connect') @@ -51,7 +55,7 @@ app = FastAPI(lifespan=lifespan) # create async socket pool with num_sockets use ZMQ_DEALER async def create_socket_pool(url: str, num_sockets: int, zmqctx: zmq.asyncio.Context) -> Queue: - sockets: Queue = Queue() + sockets: Queue[zmq.Socket] = Queue() for i in range(num_sockets): sock = zmqctx.socket(zmq.DEALER) identity = f"worker-{i}-{uuid.uuid4()}" @@ -66,20 +70,23 @@ async def create_socket_pool(url: str, num_sockets: int, # select a socket and execute task async def execute_task_async(route: str, headers: dict, request: dict, sockets: Queue): - sock = await sockets.get() + sock: zmq.Socket = await sockets.get() try: requestBody = json.dumps(request) headersJson = json.dumps(headers) logger.info("Sending requestBody: %s to %s with headers: %s", requestBody, route, headersJson) - await sock.send_multipart( + await asyncio.wait_for(sock.send_multipart( [route.encode(), headersJson.encode(), - requestBody.encode()]) + requestBody.encode()]), + timeout=time_out) logger.info("Sent end") while True: logger.info("Waiting for reply") - [contentType, reply] = await sock.recv_multipart() + [contentType, + reply] = await asyncio.wait_for(sock.recv_multipart(), + timeout=time_out) contentType_str = contentType.decode() reply_str = reply.decode() logger.info("Received result: %s, %s", contentType_str, reply_str) @@ -91,6 +98,11 @@ async def execute_task_async(route: str, headers: dict, request: dict, if "[DONE]" in reply_str: logger.info("Received stop signal, return socket") break + except asyncio.TimeoutError: + logger.error(traceback.format_exc()) + logger.error("Timeout, return socket: %s", + sock.getsockopt(zmq.IDENTITY)) + yield (context_type_error, "System Error") finally: await sockets.put(sock) @@ -101,16 +113,30 @@ async def generate_stream_response(fisrt_reply: str, async for _, reply in generator: yield reply - +async def prefill(route: str, header: dict, original_request_data: dict): + logger.info("start prefill") + generator = execute_task_async(route, header, original_request_data, + app.state.sockets_prefill) + async for contentType, reply in generator: + logger.info("contentType: %s, reply: %s", contentType, reply) + if context_type_error == contentType: + response = JSONResponse({"error": reply}) + response.status_code = 500 + return response + return True + async def decode(route: str, header: dict, original_request_data: dict): logger.info("start decode") generator = execute_task_async(route, header, original_request_data, app.state.sockets_decode) - logger.info("finish decode") async for contentType, reply in generator: logger.info("contentType: %s, reply: %s", contentType, reply) - if context_type_json == contentType: + if context_type_error == contentType: + response = JSONResponse({"error": reply}) + response.status_code = 500 + return response + elif context_type_json == contentType: return JSONResponse(reply) else: return StreamingResponse(generate_stream_response( @@ -135,12 +161,17 @@ async def chat_completions(request: Request): prefill_request['max_tokens'] = 1 route = "/v1/completions" # finish prefill - async for _ in execute_task_async(route, header, prefill_request, - app.state.sockets_prefill): - continue - - logger.info("finish prefill start decode") - response = await decode(route, header, original_request_data) + try: + prefill_response = await prefill(route, header, prefill_request) + if isinstance(prefill_response, JSONResponse): + return prefill_response + logger.info("finish prefill start decode") + response = await decode(route, header, original_request_data) + logger.info("finish decode") + except Exception as e: + logger.error("Error occurred in disagg prefill proxy server, %s", + e) + response = JSONResponse({"error": {"message": str(e)}}) return response except Exception as e: diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 35fbd517c8456..d96ad42cbbff8 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -9,6 +9,7 @@ from typing import Any, Optional import uvicorn import zmq import zmq.asyncio +import zmq.devices from fastapi import FastAPI, Request, Response from vllm import envs @@ -78,37 +79,28 @@ async def serve_http(app: FastAPI, return server.shutdown() -def proxy(clients_addr: str, workers_addr: str, - ctx: zmq.asyncio.Context) -> None: - in_socket = ctx.socket(zmq.ROUTER) - in_socket.bind(clients_addr) - out_socket = ctx.socket(zmq.DEALER) - out_socket.bind(workers_addr) - try: - zmq.proxy(in_socket, out_socket) - except zmq.ContextTerminated: - print("proxy terminated") - in_socket.close() - out_socket.close() - - async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: """Server routine""" logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg, zmq_server_port) - workers_addr = "inproc://workers" + # different zmq context can't communicate use inproc + workers_addr = "ipc://workers" clients_addr = f"ipc://127.0.0.1:{zmq_server_port}" # Prepare our context and sockets context = zmq.asyncio.Context.instance() try: tasks = [ asyncio.create_task(worker_routine(workers_addr, app, context, i)) - for i in range(5) + for i in range(20) ] logger.info("zmq tasks: %s", tasks) - proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr, - context) - await asyncio.gather(*tasks, proxy_task) + # thread safety proxy create socket in the background: + # https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices + thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER) + thread_proxy.bind_in(clients_addr) + thread_proxy.bind_out(workers_addr) + thread_proxy.start() + await asyncio.gather(*tasks) except KeyboardInterrupt: print("ZMQ Server interrupted") except zmq.ZMQError as e: