Signed-off-by: inkcherry <mingzhi.liu@amd.com>
This commit is contained in:
inkcherry 2025-11-20 07:22:02 +00:00
parent 245b71a891
commit 4f592ae696

View File

@ -1,35 +1,35 @@
import argparse # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import copy
import logging import logging
import os import os
import re
import socket import socket
import threading
import uuid import uuid
import msgpack import msgpack
import zmq import zmq
import copy
import threading
from quart import Quart, make_response, request from quart import Quart, make_response, request
import re
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from typing import Dict,List
import asyncio
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
import aiohttp

# Registration records collected by the listener thread: role "P" records go
# to prefill_instances, role "D" records go to decode_instances.
prefill_instances = []
decode_instances = []
# Incremented once per handled request; drives round-robin instance selection.
request_nums = 0
app = Quart(__name__)
yield_chunk = set()
# Captures the dotted-quad IPv4 host and the port that follow "//" in a URL.
IP_PORT_PATTERN = re.compile(r"//(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}):(\d+)")
from itertools import count

# Transfer mode shared by all registered instances; None until the first
# registration sets it.
TRANSFER_TYPE = None
def _append_whole_dict_unique(target_list, data_dict): def _append_whole_dict_unique(target_list, data_dict):
new_filtered = {k: v for k, v in data_dict.items() if k != "index"} new_filtered = {k: v for k, v in data_dict.items() if k != "index"}
for existed in target_list: for existed in target_list:
@ -44,39 +44,42 @@ def _append_whole_dict_unique(target_list, data_dict):
if TRANSFER_TYPE is None: if TRANSFER_TYPE is None:
TRANSFER_TYPE = transfer_mode TRANSFER_TYPE = transfer_mode
logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE) logger.info("SET TRANSFER TYPE TO %s", TRANSFER_TYPE)
elif TRANSFER_TYPE != transfer_mode: elif transfer_mode != TRANSFER_TYPE:
raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}") raise ValueError(f"mismatched transfer mode {TRANSFER_TYPE} vs {transfer_mode}")
return True return True
# Serializes updates to the instance lists made by the listener thread.
_list_lock = threading.RLock()


def _listen_for_register(hostname, port):
    """Receive instance registrations on a ZeroMQ ROUTER socket, forever.

    Binds ``tcp://{hostname}:{port}`` and blocks in a poll loop. Every
    incoming message is expected to be two frames (sender identity, msgpack
    payload). Payloads with ``type == "register"`` are appended to
    ``prefill_instances`` (``role == "P"``) or ``decode_instances``
    (``role == "D"``) through ``_append_whole_dict_unique``; every other
    message type, including ``"HELLO"``, is ignored.

    A message that cannot be processed is logged and skipped so that a single
    bad registration does not terminate the listener thread.

    Args:
        hostname: Interface or host name to bind.
        port: TCP port to bind.
    """
    context = zmq.Context()
    router_socket = context.socket(zmq.ROUTER)
    router_socket.bind(f"tcp://{hostname}:{port}")

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue

        try:
            _identity, msg = router_socket.recv_multipart()
            data = msgpack.loads(msg)

            if data["type"] != "register":
                # "HELLO" and unknown message types require no action.
                continue

            role = data["role"]
            if role == "P":
                target = prefill_instances
            elif role == "D":
                target = decode_instances
            else:
                continue

            # De-duplication is delegated to _append_whole_dict_unique; the
            # lists hold dicts, so testing a bare address string against them
            # could never detect a duplicate.
            with _list_lock:
                _append_whole_dict_unique(target, data)
        except Exception:
            logger.exception("Failed to process registration message")
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
hostname = socket.gethostname() hostname = socket.gethostname()
@ -84,147 +87,198 @@ def start_service_discovery(hostname, port):
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
_listener_thread = threading.Thread( _listener_thread = threading.Thread(
target = _listen_for_register,args = (hostname, port),daemon=True target=_listen_for_register, args=(hostname, port), daemon=True
) )
_listener_thread.start() _listener_thread.start()
return _listener_thread return _listener_thread
async def send_request_to_prefill(
    endpoint, req_data, request_id, p_endpoint, pip, pports, selected_prefill_dp_rank
):
    """POST a one-token, non-streaming request to a prefill instance.

    The payload is built on a copy of ``req_data`` so the caller's dict (and
    its ``kv_transfer_params`` sub-dict) is left untouched.

    Args:
        endpoint: URL of the prefill instance to call.
        req_data: Original request body (a JSON-compatible dict).
        request_id: Value sent in the ``X-Request-Id`` header.
        p_endpoint: Registration record of the remote peer; its
            ``handshake_port`` and ``notify_port`` are forwarded.
        pip: Remote peer host, forwarded as ``remote_host``.
        pports: Remote peer port, forwarded as ``remote_port``.
        selected_prefill_dp_rank: Data-parallel rank to target, or ``None``
            to omit the ``X-data-parallel-rank`` header.

    Returns:
        The decoded JSON body of the prefill response.

    Raises:
        RuntimeError: If the prefill instance answers with a non-200 status.
    """
    req_data_copy = dict(req_data)
    kv_transfer_params = dict(req_data.get("kv_transfer_params") or {})
    kv_transfer_params.update(
        {
            "do_remote_decode": True,
            "do_remote_prefill": False,
            "remote_handshake_port": p_endpoint["handshake_port"],
            "remote_notify_port": p_endpoint["notify_port"],
            "remote_engine_id": None,
            "remote_block_ids": None,
            "remote_host": pip,
            "remote_port": pports,
        }
    )
    req_data_copy["kv_transfer_params"] = kv_transfer_params

    # Prefill only needs to produce a single token and never streams.
    req_data_copy["stream"] = False
    req_data_copy["max_tokens"] = 1
    if "max_completion_tokens" in req_data_copy:
        req_data_copy["max_completion_tokens"] = 1
    req_data_copy.pop("stream_options", None)

    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    if selected_prefill_dp_rank is not None:
        headers["X-data-parallel-rank"] = str(selected_prefill_dp_rank)

    # The total timeout is 6 * 6000 * 6000 seconds, i.e. effectively unbounded.
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    ) as session:
        async with session.post(
            url=endpoint, json=req_data_copy, headers=headers
        ) as response:
            if response.status != 200:
                raise RuntimeError(
                    "send_request_to_prefill response.status != 200, "
                    f"response.status = {response.status}"
                )
            return await response.json()
async def start_decode_request(endpoint, req_data, request_id):
    """Open a session and POST the decode request without reading the body.

    Ownership of the returned session passes to the caller, which must close
    it once the response has been consumed.

    Args:
        endpoint: URL of the decode instance to call.
        req_data: JSON-compatible request body.
        request_id: Value sent in the ``X-Request-Id`` header.

    Returns:
        A ``(session, response)`` tuple.
    """
    session = aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    )
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    try:
        response = await session.post(url=endpoint, json=req_data, headers=headers)
    except BaseException:
        # Nobody else holds the session yet, so close it here (also on
        # cancellation) instead of leaking it.
        await session.close()
        raise
    return session, response
async def stream_decode_response(session, response, request_id):
    """Yield the decode response body in chunks, then close the session.

    Args:
        session: Client session that produced ``response``; it is always
            closed when the generator finishes or fails.
        response: Response object exposing ``status`` and
            ``content.iter_chunked``.
        request_id: Identifier of the request (not used by this function).

    Yields:
        Byte chunks read with ``iter_chunked(1024)``.

    Raises:
        RuntimeError: If ``response.status`` is not 200.
    """
    try:
        if response.status != 200:
            raise RuntimeError(
                f"decode response.status != 200, status = {response.status}"
            )
        async for chunk_bytes in response.content.iter_chunked(1024):
            yield chunk_bytes
    finally:
        await session.close()
async def send_request_to_decode(endpoint, req_data, request_id):
    """POST a request to a decode instance and stream back the body.

    Args:
        endpoint: URL of the decode instance to call.
        req_data: JSON-compatible request body.
        request_id: Value sent in the ``X-Request-Id`` header.

    Yields:
        Byte chunks read with ``iter_chunked(1024)``.

    Raises:
        RuntimeError: If the decode instance answers with a non-200 status.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=6 * 6000 * 6000)
    ) as session:
        async with session.post(
            url=endpoint, json=req_data, headers=headers
        ) as response:
            if response.status != 200:
                raise RuntimeError(
                    "send_request_to_decode response.status != 200, "
                    f"response.status = {response.status}"
                )
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
def example_round_robin_dp_loader(request_number, dp_size):
    """Pick a data-parallel rank by round-robin.

    Args:
        request_number: Sequence number used to rotate through the ranks.
        dp_size: Number of data-parallel ranks; must be non-zero.

    Returns:
        ``request_number`` modulo ``dp_size``.
    """
    # Use the argument, not the module-level request counter, so the result
    # depends only on what the caller passes in.
    return request_number % dp_size
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"]) @app.route("/v1/chat/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
global request_nums global request_nums
request_nums += 1 request_nums += 1
def extract_ip_port_fast(url): def extract_ip_port_fast(url):
return IP_PORT_PATTERN.search(url).groups() return IP_PORT_PATTERN.search(url).groups()
req_data = await request.get_json() req_data = await request.get_json()
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
prefill_instance_endpoint=None prefill_instance_endpoint = None
decode_instance_endpoint=None decode_instance_endpoint = None
pid=request_nums % len(prefill_instances) pid = request_nums % len(prefill_instances)
did=request_nums % len(decode_instances) did = request_nums % len(decode_instances)
prefill_instance_endpoint = prefill_instances[pid] prefill_instance_endpoint = prefill_instances[pid]
decode_instance_endpoint = decode_instances[did] decode_instance_endpoint = decode_instances[did]
selected_prefill_dp_rank = None
if prefill_instance_endpoint["dp_size"] > 1:
selected_prefill_dp_rank = example_round_robin_dp_loader(
request_nums // len(prefill_instance_endpoint),
prefill_instance_endpoint["dp_size"],
)
selected_prefill_dp_rank=None dip, dport = extract_ip_port_fast(decode_instance_endpoint["request_address"])
if prefill_instance_endpoint['dp_size']>1: ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
selected_prefill_dp_rank=example_round_robin_dp_loader(request_nums//len(prefill_instance_endpoint),prefill_instance_endpoint['dp_size'])
dip,dport= extract_ip_port_fast(decode_instance_endpoint['request_address'])
ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address'])
req_data_to_prefill = copy.deepcopy(req_data) req_data_to_prefill = copy.deepcopy(req_data)
req_data_to_prefill['kv_transfer_params']={} req_data_to_prefill["kv_transfer_params"] = {}
req_data['kv_transfer_params']={} req_data["kv_transfer_params"] = {}
req_data_to_prefill['kv_transfer_params']['remote_dp_size']=decode_instance_endpoint['dp_size'] req_data_to_prefill["kv_transfer_params"]["remote_dp_size"] = (
req_data_to_prefill['kv_transfer_params']['remote_tp_size']=decode_instance_endpoint['tp_size'] decode_instance_endpoint["dp_size"]
)
req_data_to_prefill["kv_transfer_params"]["remote_tp_size"] = (
decode_instance_endpoint["tp_size"]
send_prefill_task = asyncio.create_task(send_request_to_prefill(prefill_instance_endpoint['request_address'],req_data_to_prefill,request_id,decode_instance_endpoint,dip,dport,selected_prefill_dp_rank))
ip, port = extract_ip_port_fast(prefill_instance_endpoint['request_address'])
req_data['max_tokens'] -= 1
req_data['kv_transfer_params'] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_handshake_port": prefill_instance_endpoint['handshake_port'],
"remote_notify_port":prefill_instance_endpoint['notify_port'],
"remote_engine_id":None,
"remote_block_ids":None,
"remote_host":ip ,
"remote_port": port,
}
if TRANSFER_TYPE =="READ":
#In read mode, prefill and decode are executed serially.
prefill_response=await send_prefill_task
req_data['kv_transfer_params']['remote_engine_id']=prefill_response['kv_transfer_params']['remote_engine_id']
req_data['kv_transfer_params']['remote_block_ids']=prefill_response['kv_transfer_params']['remote_block_ids']
req_data['kv_transfer_params']['remote_dp_size'] = prefill_instance_endpoint['dp_size']
req_data['kv_transfer_params']['remote_tp_size'] = prefill_instance_endpoint['tp_size']
if selected_prefill_dp_rank is not None:
req_data['kv_transfer_params']['remote_dp_rank'] = selected_prefill_dp_rank
decode_request_task = asyncio.create_task(
start_decode_request(decode_instance_endpoint['request_address'], req_data, request_id)
) )
send_prefill_task = asyncio.create_task(
send_request_to_prefill(
prefill_instance_endpoint["request_address"],
req_data_to_prefill,
request_id,
decode_instance_endpoint,
dip,
dport,
selected_prefill_dp_rank,
)
)
ip, port = extract_ip_port_fast(prefill_instance_endpoint["request_address"])
req_data["max_tokens"] -= 1
req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
"remote_handshake_port": prefill_instance_endpoint["handshake_port"],
"remote_notify_port": prefill_instance_endpoint["notify_port"],
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": ip,
"remote_port": port,
}
if TRANSFER_TYPE == "READ":
# In read mode, prefill and decode are executed serially.
prefill_response = await send_prefill_task
req_data["kv_transfer_params"]["remote_engine_id"] = prefill_response[
"kv_transfer_params"
]["remote_engine_id"]
req_data["kv_transfer_params"]["remote_block_ids"] = prefill_response[
"kv_transfer_params"
]["remote_block_ids"]
req_data["kv_transfer_params"]["remote_dp_size"] = prefill_instance_endpoint[
"dp_size"
]
req_data["kv_transfer_params"]["remote_tp_size"] = prefill_instance_endpoint[
"tp_size"
]
if selected_prefill_dp_rank is not None:
req_data["kv_transfer_params"]["remote_dp_rank"] = selected_prefill_dp_rank
decode_request_task = asyncio.create_task(
start_decode_request(
decode_instance_endpoint["request_address"], req_data, request_id
)
)
session, decode_response = await decode_request_task session, decode_response = await decode_request_task
stream_generator = stream_decode_response(session, decode_response, request_id) stream_generator = stream_decode_response(session, decode_response, request_id)
@ -234,11 +288,12 @@ async def handle_request():
print(e) print(e)
pass pass
if __name__ == "__main__":
    # Start the registration listener thread before serving HTTP traffic.
    t = start_service_discovery("0.0.0.0", 36367)
    app.debug = True
    app.config["BODY_TIMEOUT"] = 360000
    app.config["RESPONSE_TIMEOUT"] = 360000
    app.run(host="0.0.0.0", port=10001)
    # Reached only after the HTTP server has stopped.
    t.join()