"""Toy disaggregated prefill/decode proxy for MoRIIO-connector integration tests.

Prefill and decode vLLM instances register themselves over a ZMQ ROUTER
socket; incoming OpenAI-style completion requests are fanned out: a
1-token "prefill" request populates the KV cache on a prefill instance,
then the real request is streamed from a decode instance.
"""
import asyncio
import copy
import logging
import os
import re
import socket
import threading
import uuid

import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Instances registered via the ZMQ discovery socket (lists of metadata dicts).
prefill_instances = []
decode_instances = []
# Monotonically increasing request counter driving round-robin selection.
request_nums = 0

app = Quart(__name__)

# Extracts ("ip", "port") from URLs such as "http://1.2.3.4:8000/v1/...".
IP_PORT_PATTERN = re.compile(r'//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)')

# Effectively unbounded timeout for long-running inference requests
# (value kept from the original code).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 6000 * 6000)

# KV transfer mode advertised by the instances (e.g. "READ"); pinned by the
# first registration and required to match for every later one.
TRANSFER_TYPE = None

_list_lock = threading.RLock()


def _append_whole_dict_unique(target_list, data_dict):
    """Append ``data_dict`` to ``target_list`` unless an equal entry
    (ignoring the volatile "index" key) is already present.

    Side effect: pins the global ``TRANSFER_TYPE`` to the instance's
    advertised ``transfer_mode`` and raises ``ValueError`` on a mismatch.

    Returns:
        True when the dict was appended, False when it was a duplicate.
    """
    new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
    for existed in target_list:
        existed_filtered = {k: v for k, v in existed.items() if k != "index"}
        if existed_filtered == new_filtered:
            return False

    logger.info("Registering new instance: %s", data_dict)
    target_list.append(data_dict)

    transfer_mode = data_dict.get("transfer_mode", "unknown")
    global TRANSFER_TYPE
    if TRANSFER_TYPE is None:
        TRANSFER_TYPE = transfer_mode
        logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
    elif TRANSFER_TYPE != transfer_mode:
        raise ValueError(
            f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")
    return True


def _listen_for_register(hostname, port):
    """Blocking loop: accept registration messages from prefill ("P") and
    decode ("D") instances on a ZMQ ROUTER socket and record them in the
    global instance lists."""
    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")
    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue
        _remote_addr, msg = router_socket.recv_multipart()
        data = msgpack.loads(msg)
        if data['type'] == "HELLO":
            continue
        if data['type'] == "register":
            # Deduplication lives in _append_whole_dict_unique.  The original
            # pre-check compared an address *string* against a list of
            # *dicts* and therefore never matched; it has been removed.
            if data['role'] == "P":
                with _list_lock:
                    _append_whole_dict_unique(prefill_instances, data)
            elif data['role'] == "D":
                with _list_lock:
                    _append_whole_dict_unique(decode_instances, data)


def start_service_discovery(hostname, port):
    """Start the registration listener in a daemon thread and return the
    thread object.

    Raises:
        ValueError: if ``port`` is 0 (an explicit port is required).
    """
    if not hostname:
        hostname = socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    listener_thread = threading.Thread(
        target=_listen_for_register, args=(hostname, port), daemon=True)
    listener_thread.start()
    return listener_thread


def _auth_headers(request_id):
    """Common headers for requests forwarded to backend instances."""
    return {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }


async def send_request_to_prefill(endpoint, req_data, request_id, p_endpoint,
                                  pip, pports, selected_prefill_dp_rank):
    """POST the request to the prefill instance and return its JSON body.

    ``req_data`` is mutated in place: its kv_transfer_params are pointed at
    the *decode* instance (``p_endpoint``/``pip``/``pports``) and generation
    is truncated to a single token — only the KV cache is wanted here.

    Raises:
        RuntimeError: when the prefill instance answers with a non-200 status.
    """
    req_data['kv_transfer_params'].update({
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_handshake_port": p_endpoint['handshake_port'],
        "remote_notify_port": p_endpoint['notify_port'],
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": pip,
        "remote_port": pports,
    })
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "max_completion_tokens" in req_data:
        req_data["max_completion_tokens"] = 1
    req_data.pop("stream_options", None)

    headers = _auth_headers(request_id)
    if selected_prefill_dp_rank is not None:
        headers['X-data-parallel-rank'] = str(selected_prefill_dp_rank)

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=endpoint, json=req_data,
                                headers=headers) as response:
            if response.status != 200:
                raise RuntimeError(
                    f"send_request_to_prefill failed: HTTP {response.status}")
            return await response.json()


async def start_decode_request(endpoint, req_data, request_id):
    """Kick off the decode request and return the ``(session, response)``
    pair.

    The caller owns the session; ``stream_decode_response`` closes it once
    the streamed body has been fully consumed (or on error).
    """
    session = aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT)
    response = await session.post(url=endpoint, json=req_data,
                                  headers=_auth_headers(request_id))
    return session, response


async def stream_decode_response(session, response, request_id):
    """Yield the decode response body chunk by chunk, closing ``session``
    when the stream ends or fails.

    Raises:
        RuntimeError: when the decode instance answered with a non-200 status.
    """
    try:
        if response.status != 200:
            raise RuntimeError(
                f"decode response.status != 200, status = {response.status}")
        async for chunk_bytes in response.content.iter_chunked(1024):
            yield chunk_bytes
    finally:
        await session.close()


async def send_request_to_decode(endpoint, req_data, request_id):
    """One-shot variant of the decode path: POST and stream the body.

    Raises:
        RuntimeError: when the decode instance answers with a non-200 status.
    """
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=endpoint, json=req_data,
                                headers=_auth_headers(request_id)) as response:
            if response.status != 200:
                raise RuntimeError(
                    f"send_request_to_decode failed: HTTP {response.status}")
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes


def example_round_robin_dp_loader(request_number, dp_size):
    """Round-robin data-parallel rank selection.

    Fixed to use the ``request_number`` argument; the original ignored its
    parameter and read the global ``request_nums`` instead.
    """
    return request_number % dp_size


@app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request():
    """Proxy entry point: run a 1-token prefill pass on a prefill instance,
    then stream the real completion from a decode instance."""
    global request_nums
    try:
        request_nums += 1
        req_data = await request.get_json()
        request_id = str(uuid.uuid4())

        # Round-robin over the registered instances.
        prefill_instance_endpoint = prefill_instances[
            request_nums % len(prefill_instances)]
        decode_instance_endpoint = decode_instances[
            request_nums % len(decode_instances)]

        selected_prefill_dp_rank = None
        if prefill_instance_endpoint['dp_size'] > 1:
            # BUG FIX: the original passed request_nums // len(<dict>) —
            # len() of the endpoint dict, not the instance list — while the
            # loader ignored that argument anyway.  The effective behavior
            # (request_nums % dp_size) is preserved here.
            selected_prefill_dp_rank = example_round_robin_dp_loader(
                request_nums, prefill_instance_endpoint['dp_size'])

        dip, dport = IP_PORT_PATTERN.search(
            decode_instance_endpoint['request_address']).groups()
        pip, pport = IP_PORT_PATTERN.search(
            prefill_instance_endpoint['request_address']).groups()

        # The prefill copy is told where the decode side lives.
        req_data_to_prefill = copy.deepcopy(req_data)
        req_data_to_prefill['kv_transfer_params'] = {
            'remote_dp_size': decode_instance_endpoint['dp_size'],
            'remote_tp_size': decode_instance_endpoint['tp_size'],
        }
        send_prefill_task = asyncio.create_task(
            send_request_to_prefill(
                prefill_instance_endpoint['request_address'],
                req_data_to_prefill, request_id, decode_instance_endpoint,
                dip, dport, selected_prefill_dp_rank))

        # The prefill pass produced one token; decode generates the rest.
        req_data['max_tokens'] -= 1
        req_data['kv_transfer_params'] = {
            "do_remote_decode": False,
            "do_remote_prefill": True,
            "remote_handshake_port":
                prefill_instance_endpoint['handshake_port'],
            "remote_notify_port": prefill_instance_endpoint['notify_port'],
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": pip,
            "remote_port": pport,
        }

        if TRANSFER_TYPE == "READ":
            # In READ mode prefill and decode are executed serially: decode
            # needs the engine/block ids produced by the finished prefill.
            prefill_response = await send_prefill_task
            prefill_kv = prefill_response['kv_transfer_params']
            req_data['kv_transfer_params']['remote_engine_id'] = \
                prefill_kv['remote_engine_id']
            req_data['kv_transfer_params']['remote_block_ids'] = \
                prefill_kv['remote_block_ids']

        req_data['kv_transfer_params']['remote_dp_size'] = \
            prefill_instance_endpoint['dp_size']
        req_data['kv_transfer_params']['remote_tp_size'] = \
            prefill_instance_endpoint['tp_size']

        session, decode_response = await start_decode_request(
            decode_instance_endpoint['request_address'], req_data, request_id)
        stream_generator = stream_decode_response(
            session, decode_response, request_id)
        return await make_response(stream_generator)
    except Exception:
        # The original printed the exception and returned None (an opaque
        # 500 with no traceback); log it and answer with an explicit 500.
        logger.exception("proxy request failed")
        return await make_response("proxy request failed", 500)


if __name__ == '__main__':
    t = start_service_discovery("0.0.0.0", 36367)
    app.debug = True
    app.config['BODY_TIMEOUT'] = 360000
    app.config['RESPONSE_TIMEOUT'] = 360000

    app.run(host="0.0.0.0", port=10001)
    t.join()