From 4f592ae696dbf5553f7d5d2e79807cc4ac740857 Mon Sep 17 00:00:00 2001 From: inkcherry Date: Thu, 20 Nov 2025 07:22:02 +0000 Subject: [PATCH] format Signed-off-by: inkcherry --- .../moriio_integration/toy_proxy_server.py | 267 +++++++++++------- 1 file changed, 161 insertions(+), 106 deletions(-) diff --git a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py index e375b84bd4053..3091d98366c0a 100644 --- a/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py +++ b/tests/v1/kv_connector/moriio_integration/toy_proxy_server.py @@ -1,35 +1,35 @@ -import argparse +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import asyncio +import copy import logging import os +import re import socket +import threading import uuid + import msgpack import zmq -import copy -import threading from quart import Quart, make_response, request -import re -from fastapi import FastAPI, Request -from fastapi.responses import StreamingResponse -from typing import Dict,List -import asyncio + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) import aiohttp + prefill_instances = [] -decode_instances = [] +decode_instances = [] request_nums = 0 app = Quart(__name__) yield_chunk = set() -IP_PORT_PATTERN = re.compile(r'//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)') - -from itertools import count +IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)") +TRANSFER_TYPE = None + -TRANSFER_TYPE=None def _append_whole_dict_unique(target_list, data_dict): new_filtered = {k: v for k, v in data_dict.items() if k != "index"} for existed in target_list: @@ -44,39 +44,42 @@ def _append_whole_dict_unique(target_list, data_dict): if TRANSFER_TYPE is None: TRANSFER_TYPE = transfer_mode logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) - elif TRANSFER_TYPE != transfer_mode: + elif transfer_mode != TRANSFER_TYPE: raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") - + return True + + _list_lock = threading.RLock() + def _listen_for_register(hostname, port): context = zmq.Context() router_socket = context.socket(zmq.ROUTER) router_socket.bind(f"tcp://{hostname}:{port}") poller = zmq.Poller() - poller.register(router_socket,zmq.POLLIN) + poller.register(router_socket, zmq.POLLIN) global prefill_instances global decode_instances while True: socks = dict(poller.poll()) if router_socket in socks: - - remote_addr,msg = router_socket.recv_multipart() + remote_addr, msg = router_socket.recv_multipart() data = msgpack.loads(msg) - if data['type'] == "HELLO": + if data["type"] == "HELLO": pass - elif data['type'] == "register" and data['role'] == "P": - if data['request_address'] not in prefill_instances: + elif data["type"] == "register" and data["role"] == "P": + if data["request_address"] not in prefill_instances: with _list_lock: _append_whole_dict_unique(prefill_instances, data) - elif data["type"] == "register" and data['role'] == "D": - if data['request_address'] not in decode_instances: + elif data["type"] == "register" and data["role"] == "D": + if data["request_address"] not in decode_instances: with _list_lock: _append_whole_dict_unique(decode_instances, data) + def start_service_discovery(hostname, port): if not hostname: hostname = socket.gethostname() @@ -84,147 +87,198 @@ def start_service_discovery(hostname, port): raise ValueError("Port cannot be 0") _listener_thread = threading.Thread( - target = _listen_for_register,args = (hostname, port),daemon=True + target=_listen_for_register, args=(hostname, port), daemon=True ) _listener_thread.start() return _listener_thread -async def send_request_to_prefill(endpoint,req_data,request_id,p_endpoint,pip,pports,selected_prefill_dp_rank): - req_data_copy =req_data - - - req_data_copy['kv_transfer_params'].update({ - "do_remote_decode": True, - "do_remote_prefill": False, - "remote_handshake_port": p_endpoint['handshake_port'], - "remote_notify_port":p_endpoint['notify_port'], - "remote_engine_id": None, - "remote_block_ids": None, - "remote_host":pip , - "remote_port": pports, - }) + +async def send_request_to_prefill( + endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank +): + req_data_copy = req_data + + req_data_copy["kv_transfer_params"].update( + { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_handshake_port": p_endpoint["handshake_port"], + "remote_notify_port": p_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": pip, + "remote_port": pports, + } + ) req_data_copy["stream"] = False req_data_copy["max_tokens"] = 1 if "max_completion_tokens" in req_data_copy: req_data_copy["max_completion_tokens"] = 1 if "stream_options" in req_data_copy: del req_data_copy["stream_options"] - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session: + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } if selected_prefill_dp_rank is not None: - headers['X-data-parallel-rank']=str(selected_prefill_dp_rank) - async with session.post(url=endpoint, json=req_data_copy, headers=headers) as response: + headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank) + async with session.post( + url=endpoint, json=req_data_copy, headers=headers + ) as response: if response.status == 200: return await response.json() - + else: - raise RuntimeError("send_request_to_prefill response.status != 200,response.statuus = ",response.status) + raise RuntimeError( + "send_request_to_prefill response.status != 200,response.statuus = ", + response.status, + ) + + async def start_decode_request(endpoint, req_data, request_id): - session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } response = await session.post(url=endpoint, json=req_data, headers=headers) return session, response + async def stream_decode_response(session, response, request_id): try: if response.status == 200: async for chunk_bytes in response.content.iter_chunked(1024): - yield chunk_bytes else: - raise RuntimeError(f"decode response.status != 200, status = {response.status}") + raise RuntimeError( + f"decode response.status != 200, status = {response.status}" + ) finally: await session.close() -async def send_request_to_decode(endpoint,req_data,request_id): - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session: + + +async def send_request_to_decode(endpoint, req_data, request_id): + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000) + ) as session: headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}", - "X-Request-Id": request_id + "X-Request-Id": request_id, } - async with session.post(url=endpoint, json=req_data, headers=headers) as response: + async with session.post( + url=endpoint, json=req_data, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content.iter_chunked(1024): - - yield chunk_bytes + yield chunk_bytes else: - raise RuntimeError("send_request_to_decode response.status != 200,response.statuus = ",response.status) + raise RuntimeError( + "send_request_to_decode response.status != 200,response.statuus = ", + response.status, + ) + + def example_round_robin_dp_loader(request_number, dp_size): return request_nums % dp_size + @app.route("/v1/completions", methods=["POST"]) @app.route("/v1/chat/completions", methods=["POST"]) async def handle_request(): try: - global request_nums request_nums += 1 + def extract_ip_port_fast(url): return IP_PORT_PATTERN.search(url).groups() + req_data = await request.get_json() request_id = str(uuid.uuid4()) - prefill_instance_endpoint=None - decode_instance_endpoint=None - - pid=request_nums % len(prefill_instances) - did=request_nums % len(decode_instances) + prefill_instance_endpoint = None + decode_instance_endpoint = None + + pid = request_nums % len(prefill_instances) + did = request_nums % len(decode_instances) prefill_instance_endpoint = prefill_instances[pid] decode_instance_endpoint = decode_instances[did] - - - selected_prefill_dp_rank=None - if prefill_instance_endpoint['dp_size']>1: - selected_prefill_dp_rank=example_round_robin_dp_loader(request_nums//len(prefill_instance_endpoint),prefill_instance_endpoint['dp_size']) - - dip,dport= extract_ip_port_fast(decode_instance_endpoint['request_address']) - ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address']) - + + selected_prefill_dp_rank = None + if prefill_instance_endpoint["dp_size"] > 1: + selected_prefill_dp_rank = example_round_robin_dp_loader( + request_nums // len(prefill_instance_endpoint), + prefill_instance_endpoint["dp_size"], + ) + + dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"]) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + req_data_to_prefill = copy.deepcopy(req_data) - req_data_to_prefill['kv_transfer_params']={} - req_data['kv_transfer_params']={} - req_data_to_prefill['kv_transfer_params']['remote_dp_size']=decode_instance_endpoint['dp_size'] - req_data_to_prefill['kv_transfer_params']['remote_tp_size']=decode_instance_endpoint['tp_size'] - - - - send_prefill_task = asyncio.create_task(send_request_to_prefill(prefill_instance_endpoint['request_address'],req_data_to_prefill,request_id,decode_instance_endpoint,dip,dport,selected_prefill_dp_rank)) - ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address']) - - - req_data['max_tokens'] -= 1 - - req_data['kv_transfer_params'] = { + req_data_to_prefill["kv_transfer_params"] = {} + req_data["kv_transfer_params"] = {} + req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = ( + decode_instance_endpoint["dp_size"] + ) + req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = ( + decode_instance_endpoint["tp_size"] + ) + + send_prefill_task = asyncio.create_task( + send_request_to_prefill( + prefill_instance_endpoint["request_address"], + req_data_to_prefill, + request_id, + decode_instance_endpoint, + dip, + dport, + selected_prefill_dp_rank, + ) + ) + ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"]) + + req_data["max_tokens"] -= 1 + + req_data["kv_transfer_params"] = { "do_remote_decode": False, "do_remote_prefill": True, - "remote_handshake_port": prefill_instance_endpoint['handshake_port'], - "remote_notify_port":prefill_instance_endpoint['notify_port'], - "remote_engine_id":None, - "remote_block_ids":None, - "remote_host":ip , + "remote_handshake_port": prefill_instance_endpoint["handshake_port"], + "remote_notify_port": prefill_instance_endpoint["notify_port"], + "remote_engine_id": None, + "remote_block_ids": None, + "remote_host": ip, "remote_port": port, } - if TRANSFER_TYPE =="READ": - #In read mode, prefill and decode are executed serially. - prefill_response=await send_prefill_task - req_data['kv_transfer_params']['remote_engine_id']=prefill_response['kv_transfer_params']['remote_engine_id'] - req_data['kv_transfer_params']['remote_block_ids']=prefill_response['kv_transfer_params']['remote_block_ids'] - - req_data['kv_transfer_params']['remote_dp_size'] = prefill_instance_endpoint['dp_size'] - req_data['kv_transfer_params']['remote_tp_size'] = prefill_instance_endpoint['tp_size'] - + if TRANSFER_TYPE == "READ": + # In read mode, prefill and decode are executed serially. + prefill_response = await send_prefill_task + req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[ + "kv_transfer_params" + ]["remote_engine_id"] + req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[ + "kv_transfer_params" + ]["remote_block_ids"] + + req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[ + "dp_size" + ] + req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[ + "tp_size" + ] + if selected_prefill_dp_rank is not None: - req_data['kv_transfer_params']['remote_dp_rank'] = selected_prefill_dp_rank + req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank decode_request_task = asyncio.create_task( - start_decode_request(decode_instance_endpoint['request_address'], req_data, request_id) + start_decode_request( + decode_instance_endpoint["request_address"], req_data, request_id + ) ) - session, decode_response = await decode_request_task stream_generator = stream_decode_response(session, decode_response, request_id) @@ -234,11 +288,12 @@ async def handle_request(): print(e) pass -if __name__ == '__main__': + +if __name__ == "__main__": t = start_service_discovery("0.0.0.0", 36367) - app.debug = True - app.config['BODY_TIMEOUT'] = 360000 - app.config['RESPONSE_TIMEOUT'] = 360000 + app.debug = True + app.config["BODY_TIMEOUT"] = 360000 + app.config["RESPONSE_TIMEOUT"] = 360000 app.run(host="0.0.0.0", port=10001) - t.join() \ No newline at end of file + t.join()