mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-12 00:37:04 +08:00
fix ThreadProxy
Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
parent
ee6607332e
commit
27c1afe88b
@ -41,7 +41,7 @@ async def test_connect_completions(session):
|
||||
async for chunk in response.content.iter_chunked(1024):
|
||||
try:
|
||||
decoded_chunk = chunk.decode('utf-8')
|
||||
print(f"Decoded chunk: {decoded_chunk!r}")
|
||||
# print(f"Decoded chunk: {decoded_chunk!r}")
|
||||
responseText += decoded_chunk
|
||||
except UnicodeDecodeError:
|
||||
print(f"Error decoding chunk: {chunk!r}")
|
||||
@ -55,6 +55,7 @@ async def test_connect_completions(session):
|
||||
response.json()))
|
||||
else:
|
||||
print(f"Request failed with status code {response.status}")
|
||||
print(f"Response : {await response.json()}")
|
||||
print(f"baseurl {base_url}")
|
||||
print(f"response data {extract_data(responseText)}")
|
||||
except aiohttp.ClientError as e:
|
||||
@ -98,7 +99,7 @@ def extract_data(responseText):
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
tasks = []
|
||||
for _ in range(1):
|
||||
for _ in range(2):
|
||||
tasks.append(test_connect_completions(session))
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import signal
|
||||
import traceback
|
||||
import uuid
|
||||
# from fastapi.lifespan import Lifespan
|
||||
from asyncio import Queue
|
||||
@ -17,12 +19,14 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# default prefill and decode addr
|
||||
time_out = 3
|
||||
fastapi_port = 8001
|
||||
prefill_addr = "ipc://localhost:7010"
|
||||
socket_prefill_num = 5
|
||||
socket_prefill_num = 20
|
||||
decode_addr = "ipc://localhost:7020"
|
||||
socket_decode_num = 5
|
||||
socket_decode_num = 20
|
||||
context_type_json = "application/json"
|
||||
context_type_error = "error"
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
logger = init_logger('vllm.entrypoints.connect')
|
||||
@ -51,7 +55,7 @@ app = FastAPI(lifespan=lifespan)
|
||||
# Create an async pool of num_sockets ZMQ DEALER sockets
|
||||
async def create_socket_pool(url: str, num_sockets: int,
|
||||
zmqctx: zmq.asyncio.Context) -> Queue:
|
||||
sockets: Queue = Queue()
|
||||
sockets: Queue[zmq.Socket] = Queue()
|
||||
for i in range(num_sockets):
|
||||
sock = zmqctx.socket(zmq.DEALER)
|
||||
identity = f"worker-{i}-{uuid.uuid4()}"
|
||||
@ -66,20 +70,23 @@ async def create_socket_pool(url: str, num_sockets: int,
|
||||
# select a socket and execute task
|
||||
async def execute_task_async(route: str, headers: dict, request: dict,
|
||||
sockets: Queue):
|
||||
sock = await sockets.get()
|
||||
sock: zmq.Socket = await sockets.get()
|
||||
try:
|
||||
requestBody = json.dumps(request)
|
||||
headersJson = json.dumps(headers)
|
||||
logger.info("Sending requestBody: %s to %s with headers: %s",
|
||||
requestBody, route, headersJson)
|
||||
await sock.send_multipart(
|
||||
await asyncio.wait_for(sock.send_multipart(
|
||||
[route.encode(),
|
||||
headersJson.encode(),
|
||||
requestBody.encode()])
|
||||
requestBody.encode()]),
|
||||
timeout=time_out)
|
||||
logger.info("Sent end")
|
||||
while True:
|
||||
logger.info("Waiting for reply")
|
||||
[contentType, reply] = await sock.recv_multipart()
|
||||
[contentType,
|
||||
reply] = await asyncio.wait_for(sock.recv_multipart(),
|
||||
timeout=time_out)
|
||||
contentType_str = contentType.decode()
|
||||
reply_str = reply.decode()
|
||||
logger.info("Received result: %s, %s", contentType_str, reply_str)
|
||||
@ -91,6 +98,11 @@ async def execute_task_async(route: str, headers: dict, request: dict,
|
||||
if "[DONE]" in reply_str:
|
||||
logger.info("Received stop signal, return socket")
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("Timeout, return socket: %s",
|
||||
sock.getsockopt(zmq.IDENTITY))
|
||||
yield (context_type_error, "System Error")
|
||||
finally:
|
||||
await sockets.put(sock)
|
||||
|
||||
@ -101,16 +113,30 @@ async def generate_stream_response(fisrt_reply: str,
|
||||
async for _, reply in generator:
|
||||
yield reply
|
||||
|
||||
|
||||
async def prefill(route: str, header: dict, original_request_data: dict):
    """Run the prefill stage of a request and report whether it succeeded.

    Iterates over every ``(contentType, reply)`` pair produced by
    ``execute_task_async`` on the prefill socket pool.

    Args:
        route: Route forwarded to ``execute_task_async``.
        header: Headers forwarded to ``execute_task_async``.
        original_request_data: Request body forwarded to
            ``execute_task_async``.

    Returns:
        ``True`` when all replies were consumed without an error reply;
        otherwise a ``JSONResponse`` with status code 500 wrapping the first
        error reply.
    """
    logger.info("start prefill")
    generator = execute_task_async(route, header, original_request_data,
                                   app.state.sockets_prefill)
    try:
        async for contentType, reply in generator:
            logger.info("contentType: %s, reply: %s", contentType, reply)
            if context_type_error == contentType:
                response = JSONResponse({"error": reply})
                response.status_code = 500
                return response
        return True
    finally:
        # Returning from inside the loop leaves the async generator
        # suspended at its ``yield``; close it explicitly so its ``finally``
        # clause (which puts the socket back into the pool) runs now instead
        # of whenever the generator is garbage-collected.
        await generator.aclose()
|
||||
|
||||
async def decode(route: str, header: dict, original_request_data: dict):
|
||||
logger.info("start decode")
|
||||
generator = execute_task_async(route, header, original_request_data,
|
||||
app.state.sockets_decode)
|
||||
logger.info("finish decode")
|
||||
|
||||
async for contentType, reply in generator:
|
||||
logger.info("contentType: %s, reply: %s", contentType, reply)
|
||||
if context_type_json == contentType:
|
||||
if context_type_error == contentType:
|
||||
response = JSONResponse({"error": reply})
|
||||
response.status_code = 500
|
||||
return response
|
||||
elif context_type_json == contentType:
|
||||
return JSONResponse(reply)
|
||||
else:
|
||||
return StreamingResponse(generate_stream_response(
|
||||
@ -135,12 +161,17 @@ async def chat_completions(request: Request):
|
||||
prefill_request['max_tokens'] = 1
|
||||
route = "/v1/completions"
|
||||
# finish prefill
|
||||
async for _ in execute_task_async(route, header, prefill_request,
|
||||
app.state.sockets_prefill):
|
||||
continue
|
||||
|
||||
logger.info("finish prefill start decode")
|
||||
response = await decode(route, header, original_request_data)
|
||||
try:
|
||||
prefill_response = await prefill(route, header, prefill_request)
|
||||
if isinstance(prefill_response, JSONResponse):
|
||||
return prefill_response
|
||||
logger.info("finish prefill start decode")
|
||||
response = await decode(route, header, original_request_data)
|
||||
logger.info("finish decode")
|
||||
except Exception as e:
|
||||
logger.error("Error occurred in disagg prefill proxy server, %s",
|
||||
e)
|
||||
response = JSONResponse({"error": {"message": str(e)}})
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import Any, Optional
|
||||
import uvicorn
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
import zmq.devices
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
from vllm import envs
|
||||
@ -78,37 +79,28 @@ async def serve_http(app: FastAPI,
|
||||
return server.shutdown()
|
||||
|
||||
|
||||
def proxy(clients_addr: str, workers_addr: str,
          ctx: zmq.asyncio.Context) -> None:
    """Forward messages between a client-facing and a worker-facing socket.

    Binds a ROUTER socket to ``clients_addr`` and a DEALER socket to
    ``workers_addr``, then hands both to ``zmq.proxy``.

    Args:
        clients_addr: Address the ROUTER socket binds to.
        workers_addr: Address the DEALER socket binds to.
        ctx: ZMQ context used to create both sockets.
    """
    # The context managers close both sockets on every exit path, including
    # a failed bind or an unexpected error from zmq.proxy.
    with ctx.socket(zmq.ROUTER) as in_socket, \
            ctx.socket(zmq.DEALER) as out_socket:
        in_socket.bind(clients_addr)
        out_socket.bind(workers_addr)
        try:
            zmq.proxy(in_socket, out_socket)
        except zmq.ContextTerminated:
            print("proxy terminated")
|
||||
|
||||
|
||||
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
|
||||
"""Server routine"""
|
||||
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
|
||||
zmq_server_port)
|
||||
workers_addr = "inproc://workers"
|
||||
# Sockets in different zmq contexts cannot communicate over inproc
|
||||
workers_addr = "ipc://workers"
|
||||
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
|
||||
# Prepare our context and sockets
|
||||
context = zmq.asyncio.Context.instance()
|
||||
try:
|
||||
tasks = [
|
||||
asyncio.create_task(worker_routine(workers_addr, app, context, i))
|
||||
for i in range(5)
|
||||
for i in range(20)
|
||||
]
|
||||
logger.info("zmq tasks: %s", tasks)
|
||||
proxy_task = asyncio.to_thread(proxy, clients_addr, workers_addr,
|
||||
context)
|
||||
await asyncio.gather(*tasks, proxy_task)
|
||||
# For thread safety, ThreadProxy creates its sockets in the background thread:
|
||||
# https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices
|
||||
thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER)
|
||||
thread_proxy.bind_in(clients_addr)
|
||||
thread_proxy.bind_out(workers_addr)
|
||||
thread_proxy.start()
|
||||
await asyncio.gather(*tasks)
|
||||
except KeyboardInterrupt:
|
||||
print("ZMQ Server interrupted")
|
||||
except zmq.ZMQError as e:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user