Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 07:22:02 +00:00
parent 245b71a891
commit 4f592ae696

View File

@ -1,35 +1,35 @@
import argparse
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import logging
import os
import re
import socket
import threading
import uuid
import msgpack
import zmq
import copy
import threading
from quart import Quart, make_response, request
import re
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from typing import Dict,List
import asyncio
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import aiohttp
prefill_instances = []
decode_instances = []
decode_instances = []
request_nums = 0
app = Quart(__name__)
yield_chunk = set()
IP_PORT_PATTERN = re.compile(r'//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)')
from itertools import count
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")
TRANSFER_TYPE = None
TRANSFER_TYPE=None
def _append_whole_dict_unique(target_list, data_dict):
new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
for existed in target_list:
@ -44,39 +44,42 @@ def _append_whole_dict_unique(target_list, data_dict):
if TRANSFER_TYPE is None:
TRANSFER_TYPE = transfer_mode
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
elif TRANSFER_TYPE != transfer_mode:
elif transfer_mode != TRANSFER_TYPE:
raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")
return True
_list_lock = threading.RLock()
def _listen_for_register(hostname, port):
context = zmq.Context()
router_socket = context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
poller.register(router_socket,zmq.POLLIN)
poller.register(router_socket, zmq.POLLIN)
global prefill_instances
global decode_instances
while True:
socks = dict(poller.poll())
if router_socket in socks:
remote_addr,msg = router_socket.recv_multipart()
remote_addr, msg = router_socket.recv_multipart()
data = msgpack.loads(msg)
if data['type'] == "HELLO":
if data["type"] == "HELLO":
pass
elif data['type'] == "register" and data['role'] == "P":
if data['request_address'] not in prefill_instances:
elif data["type"] == "register" and data["role"] == "P":
if data["request_address"] not in prefill_instances:
with _list_lock:
_append_whole_dict_unique(prefill_instances, data)
elif data["type"] == "register" and data['role'] == "D":
if data['request_address'] not in decode_instances:
elif data["type"] == "register" and data["role"] == "D":
if data["request_address"] not in decode_instances:
with _list_lock:
_append_whole_dict_unique(decode_instances, data)
def start_service_discovery(hostname, port):
if not hostname:
hostname = socket.gethostname()
@ -84,147 +87,198 @@ def start_service_discovery(hostname, port):
raise ValueError("Port cannot be 0")
_listener_thread = threading.Thread(
target = _listen_for_register,args = (hostname, port),daemon=True
target=_listen_for_register, args=(hostname, port), daemon=True
)
_listener_thread.start()
return _listener_thread
async def send_request_to_prefill(endpoint,req_data,request_id,p_endpoint,pip,pports,selected_prefill_dp_rank):
req_data_copy =req_data
req_data_copy['kv_transfer_params'].update({
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_handshake_port": p_endpoint['handshake_port'],
"remote_notify_port":p_endpoint['notify_port'],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host":pip ,
"remote_port": pports,
})
async def send_request_to_prefill(
endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank
):
req_data_copy = req_data
req_data_copy["kv_transfer_params"].update(
{
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_handshake_port": p_endpoint["handshake_port"],
"remote_notify_port": p_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": pip,
"remote_port": pports,
}
)
req_data_copy["stream"] = False
req_data_copy["max_tokens"] = 1
if "max_completion_tokens" in req_data_copy:
req_data_copy["max_completion_tokens"] = 1
if "stream_options" in req_data_copy:
del req_data_copy["stream_options"]
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session:
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
"X-Request-Id": request_id,
}
if selected_prefill_dp_rank is not None:
headers['X-data-parallel-rank']=str(selected_prefill_dp_rank)
async with session.post(url=endpoint, json=req_data_copy, headers=headers) as response:
headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank)
async with session.post(
url=endpoint, json=req_data_copy, headers=headers
) as response:
if response.status == 200:
return await response.json()
else:
raise RuntimeError("send_request_to_prefill response.status != 200,response.statuus = ",response.status)
raise RuntimeError(
"send_request_to_prefill response.status != 200,response.statuus = ",
response.status,
)
async def start_decode_request(endpoint, req_data, request_id):
session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000))
session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
)
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
"X-Request-Id": request_id,
}
response = await session.post(url=endpoint, json=req_data, headers=headers)
return session, response
async def stream_decode_response(session, response, request_id):
try:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
else:
raise RuntimeError(f"decode response.status != 200, status = {response.status}")
raise RuntimeError(
f"decode response.status != 200, status = {response.status}"
)
finally:
await session.close()
async def send_request_to_decode(endpoint,req_data,request_id):
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)) as session:
async def send_request_to_decode(endpoint, req_data, request_id):
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
) as session:
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
"X-Request-Id": request_id,
}
async with session.post(url=endpoint, json=req_data, headers=headers) as response:
async with session.post(
url=endpoint, json=req_data, headers=headers
) as response:
if response.status == 200:
async for chunk_bytes in response.content.iter_chunked(1024):
yield chunk_bytes
yield chunk_bytes
else:
raise RuntimeError("send_request_to_decode response.status != 200,response.statuus = ",response.status)
raise RuntimeError(
"send_request_to_decode response.status != 200,response.statuus = ",
response.status,
)
def example_round_robin_dp_loader(request_number, dp_size):
return request_nums % dp_size
@app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request():
try:
global request_nums
request_nums += 1
def extract_ip_port_fast(url):
return IP_PORT_PATTERN.search(url).groups()
req_data = await request.get_json()
request_id = str(uuid.uuid4())
prefill_instance_endpoint=None
decode_instance_endpoint=None
pid=request_nums % len(prefill_instances)
did=request_nums % len(decode_instances)
prefill_instance_endpoint = None
decode_instance_endpoint = None
pid = request_nums % len(prefill_instances)
did = request_nums % len(decode_instances)
prefill_instance_endpoint = prefill_instances[pid]
decode_instance_endpoint = decode_instances[did]
selected_prefill_dp_rank=None
if prefill_instance_endpoint['dp_size']>1:
selected_prefill_dp_rank=example_round_robin_dp_loader(request_nums//len(prefill_instance_endpoint),prefill_instance_endpoint['dp_size'])
dip,dport= extract_ip_port_fast(decode_instance_endpoint['request_address'])
ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address'])
selected_prefill_dp_rank = None
if prefill_instance_endpoint["dp_size"] > 1:
selected_prefill_dp_rank = example_round_robin_dp_loader(
request_nums // len(prefill_instance_endpoint),
prefill_instance_endpoint["dp_size"],
)
dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
req_data_to_prefill = copy.deepcopy(req_data)
req_data_to_prefill['kv_transfer_params']={}
req_data['kv_transfer_params']={}
req_data_to_prefill['kv_transfer_params']['remote_dp_size']=decode_instance_endpoint['dp_size']
req_data_to_prefill['kv_transfer_params']['remote_tp_size']=decode_instance_endpoint['tp_size']
send_prefill_task = asyncio.create_task(send_request_to_prefill(prefill_instance_endpoint['request_address'],req_data_to_prefill,request_id,decode_instance_endpoint,dip,dport,selected_prefill_dp_rank))
ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address'])
req_data['max_tokens'] -= 1
req_data['kv_transfer_params'] = {
req_data_to_prefill["kv_transfer_params"] = {}
req_data["kv_transfer_params"] = {}
req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = (
decode_instance_endpoint["dp_size"]
)
req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = (
decode_instance_endpoint["tp_size"]
)
send_prefill_task = asyncio.create_task(
send_request_to_prefill(
prefill_instance_endpoint["request_address"],
req_data_to_prefill,
request_id,
decode_instance_endpoint,
dip,
dport,
selected_prefill_dp_rank,
)
)
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
req_data["max_tokens"] -= 1
req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_handshake_port": prefill_instance_endpoint['handshake_port'],
"remote_notify_port":prefill_instance_endpoint['notify_port'],
"remote_engine_id":None,
"remote_block_ids":None,
"remote_host":ip ,
"remote_handshake_port": prefill_instance_endpoint["handshake_port"],
"remote_notify_port": prefill_instance_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": ip,
"remote_port": port,
}
if TRANSFER_TYPE =="READ":
#In read mode, prefill and decode are executed serially.
prefill_response=await send_prefill_task
req_data['kv_transfer_params']['remote_engine_id']=prefill_response['kv_transfer_params']['remote_engine_id']
req_data['kv_transfer_params']['remote_block_ids']=prefill_response['kv_transfer_params']['remote_block_ids']
req_data['kv_transfer_params']['remote_dp_size'] = prefill_instance_endpoint['dp_size']
req_data['kv_transfer_params']['remote_tp_size'] = prefill_instance_endpoint['tp_size']
if TRANSFER_TYPE == "READ":
# In read mode, prefill and decode are executed serially.
prefill_response = await send_prefill_task
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[
"kv_transfer_params"
]["remote_engine_id"]
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[
"kv_transfer_params"
]["remote_block_ids"]
req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
"dp_size"
]
req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[
"tp_size"
]
if selected_prefill_dp_rank is not None:
req_data['kv_transfer_params']['remote_dp_rank'] = selected_prefill_dp_rank
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
decode_request_task = asyncio.create_task(
start_decode_request(decode_instance_endpoint['request_address'], req_data, request_id)
start_decode_request(
decode_instance_endpoint["request_address"], req_data, request_id
)
)
session, decode_response = await decode_request_task
stream_generator = stream_decode_response(session, decode_response, request_id)
@ -234,11 +288,12 @@ async def handle_request():
print(e)
pass
if __name__ == '__main__':
if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 36367)
app.debug = True
app.config['BODY_TIMEOUT'] = 360000
app.config['RESPONSE_TIMEOUT'] = 360000
app.debug = True
app.config["BODY_TIMEOUT"] = 360000
app.config["RESPONSE_TIMEOUT"] = 360000
app.run(host="0.0.0.0", port=10001)
t.join()
t.join()