diff --git a/benchmarks/disagg_benchmarks/zmq/test_request.py b/benchmarks/disagg_benchmarks/zmq/test_request.py index 87c3b34d0181e..1e3ce27bd85a0 100644 --- a/benchmarks/disagg_benchmarks/zmq/test_request.py +++ b/benchmarks/disagg_benchmarks/zmq/test_request.py @@ -21,6 +21,7 @@ async def test_connect_completions(session): "repetition_penalty": 1.2, "model": "facebook/opt-125m", "prompt": "Can you introduce vllm?", + # "stream": False, "stream": True, "stream_options": { "include_usage": True @@ -34,13 +35,19 @@ async def test_connect_completions(session): responseText = "" if response.status == 200: transfer_encoding = response.headers.get('Transfer-Encoding') + content_type = response.headers.get('Content-Type') + print(f"Transfer-Encoding: {transfer_encoding}") if transfer_encoding == 'chunked': async for chunk in response.content.iter_chunked(1024): try: decoded_chunk = chunk.decode('utf-8') + print(f"Decoded chunk: {decoded_chunk!r}") responseText += decoded_chunk except UnicodeDecodeError: print(f"Error decoding chunk: {chunk!r}") + elif 'application/json' in content_type: + responseText = await response.json() + print(f"response {responseText!r}") else: # Print the headers and JSON response print("Unexpected Transfer-Encoding: {} {} {}".format( @@ -48,18 +55,30 @@ async def test_connect_completions(session): response.json())) else: print(f"Request failed with status code {response.status}") - print( - f"baseurl {base_url} response data {extract_data(responseText)}" - ) + print(f"baseurl {base_url}") + print(f"response data {extract_data(responseText)}") except aiohttp.ClientError as e: print(f"Error: {e}") +def is_json(data): + try: + json.loads(data) + return True + except ValueError: + return False def extract_data(responseText): + if responseText == "": + return "" + if is_json(responseText): + return responseText reply = "" for data in responseText.split("\n\n"): if data.startswith('data: '): content = data[6:] + if content == "[DONE]": + print("DONE") + break try: json_data = json.loads(content) choices = json_data["choices"] diff --git a/vllm/entrypoints/disagg_connector.py b/vllm/entrypoints/disagg_connector.py index 48c8b2b2abc66..71beeaf1e6083 100644 --- a/vllm/entrypoints/disagg_connector.py +++ b/vllm/entrypoints/disagg_connector.py @@ -6,20 +6,23 @@ import uuid # from fastapi.lifespan import Lifespan from asyncio import Queue from contextlib import asynccontextmanager +from typing import AsyncGenerator import uvicorn import zmq import zmq.asyncio from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from vllm.logger import init_logger -# default prefill and decode url -url_prefill = "tcp://localhost:8110" +# default prefill and decode addr +fastapi_port = 8001 +prefill_addr = "ipc://localhost:7010" socket_prefill_num = 5 -url_decode = "tcp://localhost:8220" +decode_addr = "ipc://localhost:7020" socket_decode_num = 5 +context_type_json = "application/json" # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765) logger = init_logger('vllm.entrypoints.connect') @@ -77,16 +80,44 @@ async def execute_task_async(route: str, headers: dict, request: dict, while True: logger.info("Waiting for reply") [contentType, reply] = await sock.recv_multipart() - logger.info("Received result: %s, %s", contentType, reply) - reply = reply.decode() - yield f"{reply}" - if "[DONE]" in reply: + contentType_str = contentType.decode() + reply_str = reply.decode() + logger.info("Received result: %s, %s", contentType_str, reply_str) + yield (contentType_str, reply_str) + if context_type_json == contentType_str: + logger.info("Received %s message, return socket", + contentType_str) + break + if "[DONE]" in reply_str: logger.info("Received stop signal, return socket") break finally: await sockets.put(sock) +async def generate_stream_response(fisrt_reply: str, + generator: AsyncGenerator): + yield fisrt_reply + async for _, reply in generator: + yield reply + + +async def decode(route: str, header: dict, original_request_data: dict): + logger.info("start decode") + generator = execute_task_async(route, header, original_request_data, + app.state.sockets_decode) + logger.info("finish decode") + + async for contentType, reply in generator: + logger.info("contentType: %s, reply: %s", contentType, reply) + if context_type_json == contentType: + return JSONResponse(reply) + else: + return StreamingResponse(generate_stream_response( + reply, generator), + media_type="text/event-stream") + + @app.post('/v1/connect/completions') async def chat_completions(request: Request): try: @@ -108,11 +139,9 @@ async def chat_completions(request: Request): app.state.sockets_prefill): continue - # return decode - return StreamingResponse(execute_task_async(route, header, - original_request_data, - app.state.sockets_decode), - media_type="text/event-stream") + logger.info("finish prefill start decode") + response = await decode(route, header, original_request_data) + return response except Exception as e: import sys @@ -127,13 +156,14 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None: logger.info("vLLM Disaggregate Connector start %s %s", args, uvicorn_kwargs) logger.info(args.prefill_addr) - - app.state.prefill_addr = (f"tcp://{args.prefill_addr}" if args.prefill_addr - is not None else url_prefill) - app.state.decode_addr = (f"tcp://{args.decode_addr}" - if args.decode_addr is not None else url_decode) - logger.info("start connect url_prefill: %s url_decode: %s", - app.state.prefill_addr, app.state.decode_addr) + app.state.port = args.port if args.port is not None else fastapi_port + app.state.prefill_addr = (f"ipc://{args.prefill_addr}" if args.prefill_addr + is not None else decode_addr) + app.state.decode_addr = (f"ipc://{args.decode_addr}" + if args.decode_addr is not None else decode_addr) + logger.info( + "start connect prefill_addr: %s decode_addr: %s zmq server port: %s", + app.state.prefill_addr, app.state.decode_addr, app.state.port) def signal_handler(*_) -> None: # Interrupt server on sigterm while initializing @@ -141,11 +171,11 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None: signal.signal(signal.SIGTERM, signal_handler) # init uvicorn server - config = uvicorn.Config(app, host="0.0.0.0", port=8001) + config = uvicorn.Config(app, host="0.0.0.0", port=app.state.port) server = uvicorn.Server(config) await server.serve() if __name__ == "__main__": # url = 'tcp://127.0.0.1:5555' - uvicorn.run(app, host="0.0.0.0", port=8001) + uvicorn.run(app, host="0.0.0.0", port=fastapi_port) diff --git a/vllm/entrypoints/launcher.py b/vllm/entrypoints/launcher.py index 58a0b749e28dd..e9aa540d6aa22 100644 --- a/vllm/entrypoints/launcher.py +++ b/vllm/entrypoints/launcher.py @@ -82,22 +82,22 @@ async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None: """Server routine""" logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg, zmq_server_port) - url_worker = "inproc://workers" - url_client = f"tcp://0.0.0.0:{zmq_server_port}" + workers_addr = "inproc://workers" + clients_addr = f"ipc://127.0.0.1:{zmq_server_port}" # Prepare our context and sockets context = zmq.asyncio.Context() # Socket to talk to clients clients = context.socket(zmq.ROUTER) - clients.bind(url_client) - logger.info("ZMQ Server ROUTER started at %s", url_client) + clients.bind(clients_addr) + logger.info("ZMQ Server ROUTER started at %s", clients_addr) # Socket to talk to workers workers = context.socket(zmq.DEALER) - workers.bind(url_worker) - logger.info("ZMQ Worker DEALER started at %s", url_worker) + workers.bind(workers_addr) + logger.info("ZMQ Worker DEALER started at %s", workers_addr) tasks = [ - asyncio.create_task(worker_routine(url_worker, app, context, i)) + asyncio.create_task(worker_routine(workers_addr, app, context, i)) for i in range(5) ] proxy_task = asyncio.to_thread(zmq.proxy, clients, workers) diff --git a/vllm/entrypoints/openai/connect_worker.py b/vllm/entrypoints/openai/connect_worker.py index b5665bd9e742d..adefac87c0921 100644 --- a/vllm/entrypoints/openai/connect_worker.py +++ b/vllm/entrypoints/openai/connect_worker.py @@ -53,7 +53,7 @@ def bytes_to_headers(bytes_data: bytes) -> httpx.Headers: headers_dict = json.loads(bytes_data.decode()) return httpx.Headers(headers_dict) -async def worker_routine(worker_url: str, app: FastAPI, +async def worker_routine(worker_addr: str, app: FastAPI, context: zmq.asyncio.Context, i: int = 0): """Worker routine""" try: @@ -61,8 +61,8 @@ async def worker_routine(worker_url: str, app: FastAPI, socket = context.socket(zmq.DEALER) worker_identity = f"worker-{i}-{uuid.uuid4()}" socket.setsockopt(zmq.IDENTITY, worker_identity.encode()) - socket.connect(worker_url) - logger.info("%s started at %s", worker_identity, worker_url) + socket.connect(worker_addr) + logger.info("%s started at %s", worker_identity, worker_addr) while True: identity, url, header, body = await socket.recv_multipart() logger.info("worker-%d Received request identity: [ %s ]", @@ -81,15 +81,16 @@ async def worker_routine(worker_url: str, app: FastAPI, createRequest = create_request(url_str, "POST", body_json, headers) generator = await create_completion(app, completionRequest, createRequest) - logger.info("worker-%d Received response: [ %s ]", i, generator) + context_type_json = b"application/json" if isinstance(generator, ErrorResponse): content = generator.model_dump_json() context_json = json.loads(content) context_json.append("status_code", generator.code) - await socket.send_multipart([identity, b"application/json", + await socket.send_multipart([identity, context_type_json, json.dumps(context_json).encode('utf-8')]) elif isinstance(generator, CompletionResponse): - await socket.send_multipart([identity, b"application/json", + await socket.send_multipart([identity, + context_type_json, json.dumps(generator.model_dump()).encode('utf-8')]) else: async for chunk in generator: