mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-22 07:45:01 +08:00
155 lines
4.6 KiB
Python
155 lines
4.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import os
|
|
import socket
|
|
import threading
|
|
import uuid
|
|
|
|
import aiohttp
|
|
import msgpack
|
|
import zmq
|
|
from quart import Quart, make_response, request
|
|
|
|
# Number of requests handled so far; handle_request uses it modulo the number
# of registered instances to pick the next prefill/decode pair.
count = 0

# Instance registries, filled in by the service-discovery listener.
# Key: HTTP address ("ip:port"); value: the matching ZMQ address ("ip:port").
prefill_instances: dict[str, str] = {}
decode_instances: dict[str, str] = {}

# One condition object per registry. In this file they are only ever entered
# with ``with``, i.e. used as plain locks around registry reads and writes.
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
|
|
|
|
|
|
def _listen_for_register(poller, router_socket):
    """Receive registration messages forever and record each instance.

    Blocks on ``poller.poll()``. Whenever ``router_socket`` is readable, one
    two-frame message ``(remote_address, payload)`` is read and the payload
    is decoded with msgpack. The decoded mapping is expected to look like::

        {"type": "P" | "D", "http_address": "ip:port", "zmq_address": "ip:port"}

    ``"P"`` entries are stored in ``prefill_instances`` and ``"D"`` entries in
    ``decode_instances`` (HTTP address -> ZMQ address), each under its own
    lock. Any other type is reported on stdout and otherwise ignored.

    Args:
        poller: Object whose ``poll()`` returns ``(socket, event)`` pairs.
        router_socket: Socket registered with ``poller`` that delivers the
            registration messages via ``recv_multipart()``.

    This function never returns; it is meant to run in a background thread.
    """
    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue

        remote_address, message = router_socket.recv_multipart()
        # data: {"type": "P", "http_address": "ip:port",
        #        "zmq_address": "ip:port"}
        data = msgpack.loads(message)

        # The registries are only mutated in place (never rebound), so no
        # ``global`` declarations are needed here.
        if data["type"] == "P":
            with prefill_cv:
                prefill_instances[data["http_address"]] = data["zmq_address"]
        elif data["type"] == "D":
            with decode_cv:
                decode_instances[data["http_address"]] = data["zmq_address"]
        else:
            # print() performs no %-interpolation, so the values are
            # formatted explicitly instead of being passed as extra args.
            print(
                f"Unexpected, Received message from {remote_address}, "
                f"data: {data}"
            )
|
|
|
|
|
|
def start_service_discovery(hostname, port):
    """Bind the registration socket and start the listener thread.

    Args:
        hostname: Host/interface to bind to. A falsy value is replaced by
            ``socket.gethostname()``.
        port: TCP port to bind to; must not be 0.

    Returns:
        The already-started daemon ``threading.Thread`` that runs
        ``_listen_for_register``.

    Raises:
        ValueError: If ``port`` is 0.
    """
    hostname = hostname or socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    # ROUTER socket on which instances announce themselves.
    ctx = zmq.Context()
    router_socket = ctx.socket(zmq.ROUTER)
    endpoint = f"tcp://{hostname}:{port}"
    router_socket.bind(endpoint)

    # The listener blocks on this poller until the socket becomes readable.
    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    listener = threading.Thread(
        target=_listen_for_register,
        args=[poller, router_socket],
        daemon=True,
    )
    listener.start()
    return listener
|
|
|
|
|
|
# Total time budget for one forwarded HTTP request: 6 hours, in seconds.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

# Quart application that exposes the proxy's HTTP endpoint(s).
app = Quart(__name__)
|
|
|
|
|
|
def random_uuid() -> str:
    """Return a fresh random UUID4 as a 32-character lowercase hex string."""
    # ``UUID.hex`` is already a ``str``; no extra conversion is required.
    return uuid.uuid4().hex
|
|
|
|
|
|
async def forward_request(url, data, request_id):
    """POST ``data`` to ``url`` and stream the response body back.

    This is an async generator. On an HTTP 200 response it yields the body
    in chunks of up to 1024 bytes as they arrive. On any other status it
    yields nothing; the status is reported on stdout so the failure is
    visible instead of being dropped silently.

    Args:
        url: Full URL of the backend endpoint.
        data: JSON-serialisable request body.
        request_id: Value sent in the ``X-Request-Id`` header.

    Yields:
        ``bytes`` chunks of the response body.
    """
    # NOTE: if OPENAI_API_KEY is unset the header value is the literal
    # string "Bearer None" (os.environ.get returns None).
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                print(
                    f"forward_request: {url} returned HTTP {response.status} "
                    f"for request {request_id}; nothing forwarded"
                )
                return

            # Stream instead of buffering the whole body; the former
            # read-everything branch was unreachable behind ``if True``.
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
|
|
|
|
|
|
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Proxy one completion request through a prefill and a decode instance.

    Steps:
      1. Copy the incoming JSON body and force ``max_tokens = 1`` so the
         prefill instance only performs prefill.
      2. Pick one prefill and one decode instance using the global request
         counter modulo the number of registered instances.
      3. Send the prefill request and drain its response.
      4. Send the original request to the decode instance and stream its
         response back to the client.

    Returns:
        The streamed decode response on success; a JSON error with status
        503 when no prefill or decode instance is registered; a JSON error
        with status 500 when any other exception occurs (details are printed
        to stdout).
    """
    global count
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        # Snapshot each registry under its lock; selection then works on the
        # private copies.
        with prefill_cv:
            prefill_list = list(prefill_instances.items())
        with decode_cv:
            decode_list = list(decode_instances.items())

        # Without this guard ``count % len(...)`` divides by zero whenever a
        # request arrives before any instance has registered.
        if not prefill_list or not decode_list:
            print(
                "handle_request: no instance available "
                f"(prefill: {len(prefill_list)}, decode: {len(decode_list)})"
            )
            return {"error": "No prefill/decode instance registered yet"}, 503

        prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
        decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]

        print(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, "
            f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
            f"ZMQ:{decode_zmq_addr}]"
        )
        count += 1

        # The request id carries both ZMQ addresses plus a random suffix.
        # NOTE(review): presumably the backends parse the addresses out of
        # this id -- confirm before changing the format.
        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
            f"{decode_zmq_addr}_{random_uuid()}"
        )

        # finish prefill (its output is intentionally discarded)
        async for _ in forward_request(
            f"http://{prefill_addr}/v1/completions", prefill_request, request_id
        ):
            continue

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data, request_id
        )
        response = await make_response(generator)
        # NOTE(review): presumably disables Quart's response timeout so long
        # streams are not cut off -- confirm.
        response.timeout = None

        return response

    except Exception as e:
        import traceback

        print("Error occurred in disagg prefill proxy server")
        print(e)
        print(traceback.format_exc())
        # Return an explicit error; falling through would return None, which
        # is not a valid view response.
        return {"error": "Internal error in disagg prefill proxy server"}, 500
|
|
|
|
|
|
if __name__ == "__main__":
    # Service discovery listens on port 30001 (all interfaces) in a
    # background thread; the HTTP proxy itself is served on port 10001.
    discovery_thread = start_service_discovery("0.0.0.0", 30001)
    app.run(host="0.0.0.0", port=10001)
    # Reached only after the HTTP server has stopped.
    discovery_thread.join()
|