# SPDX-License-Identifier: Apache-2.0
"""HTTP proxy that chains a prefill instance and a decode instance.

Instances announce themselves over a ZMQ ROUTER socket with a msgpack
payload; the proxy then serves ``POST /v1/completions`` by first sending the
request (with ``max_tokens`` forced to 1) to a prefill instance and then
streaming the reply of a decode instance back to the caller.
"""

import os
import socket
import threading
import traceback
import uuid

import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request

# Number of requests handled so far; drives round-robin instance selection.
count = 0
prefill_instances: dict[str, str] = {}  # http_address: zmq_address
decode_instances: dict[str, str] = {}  # http_address: zmq_address

# Guard the registries above: they are written by the listener thread and
# read by the request handler.
prefill_cv = threading.Condition()
decode_cv = threading.Condition()


def _listen_for_register(poller, router_socket):
    """Receive registration messages forever and record the instances.

    Each message is expected to be a two-frame multipart message whose second
    frame is a msgpack map of the form::

        {"type": "P" | "D", "http_address": "ip:port", "zmq_address": "ip:port"}

    ``"P"`` entries go to ``prefill_instances`` and ``"D"`` entries to
    ``decode_instances``. A message that cannot be processed is reported and
    skipped so that one bad sender cannot terminate the listener thread.

    Args:
        poller: Poller on which ``router_socket`` is registered.
        router_socket: Socket the registration messages arrive on.
    """
    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue

        try:
            remote_address, message = router_socket.recv_multipart()
            # NOTE(review): the payload comes from the network; make sure only
            # trusted peers can reach the discovery port.
            data = msgpack.loads(message)

            if data["type"] == "P":
                with prefill_cv:
                    prefill_instances[data["http_address"]] = data["zmq_address"]
            elif data["type"] == "D":
                with decode_cv:
                    decode_instances[data["http_address"]] = data["zmq_address"]
            else:
                print(
                    f"Unexpected, Received message from {remote_address}, "
                    f"data: {data}"
                )
        except Exception:
            # Thread boundary: report the problem and keep listening.
            print("Failed to process a registration message")
            print(traceback.format_exc())


def start_service_discovery(hostname, port):
    """Bind the registration socket and start the listener thread.

    Args:
        hostname: Interface to bind to; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind to. Must not be 0.

    Returns:
        The started daemon thread running ``_listen_for_register``.

    Raises:
        ValueError: If ``port`` is 0.
    """
    if not hostname:
        hostname = socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    _listener_thread = threading.Thread(
        target=_listen_for_register, args=(poller, router_socket), daemon=True
    )
    _listener_thread.start()
    return _listener_thread


# Overall client timeout for upstream requests: six hours.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)


def random_uuid() -> str:
    """Return a random UUID4 as a 32-character hexadecimal string."""
    return uuid.uuid4().hex


async def forward_request(url, data, request_id):
    """POST ``data`` as JSON to ``url`` and yield the response body in chunks.

    The request carries a bearer token taken from the ``OPENAI_API_KEY``
    environment variable and ``request_id`` in the ``X-Request-Id`` header.

    Args:
        url: Full URL of the upstream endpoint.
        data: JSON-serialisable request payload.
        request_id: Value sent in the ``X-Request-Id`` header.

    Yields:
        Chunks of at most 1024 bytes of the response body when the upstream
        answers with HTTP 200. Nothing is yielded for any other status.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
            "X-Request-Id": request_id,
        }
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                # NOTE(review): the original formatting was lost; the
                # read-whole-body branch was taken to be the unreachable
                # `else` of `if True`, i.e. non-200 replies produced no
                # output. That is preserved, but no longer silent.
                print(
                    f"forward_request: {url} answered with HTTP "
                    f"{response.status}; nothing forwarded"
                )
                return
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes


@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Run one completion request through a prefill and a decode instance.

    The same round-robin index (``count``) selects both instances. The
    prefill response is consumed and discarded; the decode response is
    streamed back to the caller.

    Returns:
        The streaming decode response on success; a 400 response when the
        body is not a JSON object; a 503 response when no prefill or no
        decode instance has registered yet; a 500 response when an
        unexpected error occurs before streaming starts.
    """
    global count

    try:
        # silent=True turns an unparsable body into None instead of raising.
        original_request_data = await request.get_json(silent=True)
        if not isinstance(original_request_data, dict):
            return {"error": "Request body must be a JSON object"}, 400

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        # Snapshot the registries under their locks, then work on the copies.
        with prefill_cv:
            prefill_list = list(prefill_instances.items())
        with decode_cv:
            decode_list = list(decode_instances.items())

        # Without this guard the modulo below divides by zero.
        if not prefill_list or not decode_list:
            return {
                "error": "No prefill or decode instance is registered yet"
            }, 503

        prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
        decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]

        print(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, "
            f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
            f"ZMQ:{decode_zmq_addr}]"
        )
        count += 1

        # The id embeds both ZMQ addresses; presumably the instances parse it
        # to locate each other (confirm against the instance-side code).
        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
            f"{decode_zmq_addr}_{random_uuid()}"
        )

        # finish prefill
        async for _ in forward_request(
            f"http://{prefill_addr}/v1/completions", prefill_request, request_id
        ):
            continue

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data, request_id
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print(traceback.format_exc())
        # A view must return a response; falling through would return None.
        return {"error": "Internal proxy error"}, 500


if __name__ == "__main__":
    t = start_service_discovery("0.0.0.0", 30001)
    app.run(host="0.0.0.0", port=10001)
    t.join()