refactor disagg

Signed-off-by: clark <panf2333@gmail.com>
This commit is contained in:
clark 2025-03-08 20:51:39 +08:00
parent 4f13e89143
commit 912031ceb5
10 changed files with 713 additions and 318 deletions

View File

@@ -26,6 +26,14 @@ cleanup() {
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# install quart first -- required for the disagg prefill proxy server
if python3 -c "import quart" &> /dev/null; then
echo "Quart is already installed."
else
echo "Quart is not installed. Installing..."
python3 -m pip install quart
fi
# a function that waits for the vLLM server to start
wait_for_server() {
local port=$1
@@ -41,7 +49,6 @@ wait_for_server() {
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
--port 8100 \
--zmq-server-port 7010 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
@@ -51,25 +58,13 @@ CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL_NAME \
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL_NAME \
--port 8200 \
--zmq-server-port 7011 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' &
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to prefill vLLM instance (via zmq port 7010), change max_tokens
# to 1
# - after the prefill vLLM finishes prefill, send the request to decode vLLM
# instance (via zmq port 7011)
vllm connect --port 8000 \
--prefill-addr 127.0.0.1:7010 \
--decode-addr 127.0.0.1:7011 &
# wait until prefill, decode instances and proxy are ready
wait_for_server 8000
# wait until prefill and decode instances are ready
wait_for_server 8100
wait_for_server 8200
@@ -118,4 +113,4 @@ echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
echo ""

View File

@@ -0,0 +1,113 @@
#!/bin/bash
# This file demonstrates the example usage of disaggregated prefilling with ZMQ
# We will launch 2 vllm instances (1 for prefill and 1 for decode),
# and then transfer the KV cache between them.
set -xe
echo "🚧🚧 Warning: The usage of disaggregated prefill is experimental and subject to change 🚧🚧"
sleep 1
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
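# VLLM_HOST_IP tells vLLM which IP to use for distributed (KV transfer) communication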
# a function that waits for the vLLM connect proxy to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
# a function that waits for a vLLM disagg instance to start
wait_for_disagg_server() {
local log_file=$1
timeout 1200 bash -c "
until grep -q 'zmq Server started at' $log_file; do
sleep 1
done" && return 0 || return 1
}
# You can also adjust --kv-ip and --kv-port for distributed inference.
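# --zmq-server-addr names a local IPC endpoint: each instance binds ipc://<addr>
# (a socket file in the working directory), so the two instances need distinct names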
# prefilling instance, which is the KV producer
CUDA_VISIBLE_DEVICES=0 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
--zmq-server-addr testipc0 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' > vllm_disagg_prefill.log 2>&1 &
# decoding instance, which is the KV consumer
CUDA_VISIBLE_DEVICES=1 vllm disagg meta-llama/Meta-Llama-3.1-8B-Instruct \
--zmq-server-addr testipc1 \
--max-model-len 100 \
--gpu-memory-utilization 0.8 \
--kv-transfer-config \
'{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' > vllm_disagg_decode.log 2>&1 &
# launch a proxy server that opens the service at port 8000
# the workflow of this proxy:
# - send the request to the prefill vLLM instance (via zmq addr testipc0),
#   with max_tokens changed to 1
# - after the prefill vLLM instance finishes prefill, send the request to the
#   decode vLLM instance (via zmq addr testipc1)
vllm connect --port 8000 \
--prefill-addr testipc0 \
--decode-addr testipc1 &
# wait until the prefill and decode instances and the proxy are ready
wait_for_server 8000
wait_for_disagg_server vllm_disagg_prefill.log
wait_for_disagg_server vllm_disagg_decode.log
# serve two example requests
output1=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "San Francisco is a",
"max_tokens": 10,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
"prompt": "Santa Clara is a",
"max_tokens": 10,
"temperature": 0
}')
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo ""
sleep 1
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""

View File

@@ -0,0 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import uvloop
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.disagg_connector import run_disagg_connector
from vllm.utils import FlexibleArgumentParser
class ConnectSubcommand(CLISubcommand):
"""The `connect` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "connect"
super().__init__()
@staticmethod
def cmd(args: argparse.Namespace) -> None:
uvloop.run(run_disagg_connector(args))
def validate(self, args: argparse.Namespace) -> None:
validate_connect_parsed_args(args)
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
connect_parser = subparsers.add_parser(
"connect",
help=
"Start the vLLM OpenAI Compatible API Server which connect to other"
"servers disaggreate prefill and decode",
usage="vllm connect [options]")
return make_connect_arg_parser(connect_parser)
def cmd_init() -> list[CLISubcommand]:
return [ConnectSubcommand()]
def make_connect_arg_parser(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument("--port",
type=int,
default=7001,
help="The fastapi server port default 7001")
# only ipc is supported for now; tcp (with auth) may be supported later
parser.add_argument(
"--protocol",
type=str,
choices=["ipc"],
default="ipc",
help="The zmq socket addr protocol IPC (Inter-Process Communication)")
# due to security concerns, only ipc is supported for now
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The zmq ipc decode address")
return parser
def validate_connect_parsed_args(args: argparse.Namespace):
"""Quick checks for connect args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "connect":
return

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.cli_args import (LoRAParserAction,
PromptAdapterParserAction)
from vllm.entrypoints.openai.zmq_server import run_zmq_server
from vllm.utils import FlexibleArgumentParser
class DisaggSubcommand(CLISubcommand):
"""The `disagg` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "disagg"
super().__init__()
@staticmethod
def cmd(args: argparse.Namespace) -> None:
# the model must be provided positionally, not via the `--model` option
if not args.model_tag:
raise ValueError(
"With `vllm disagg`, you should provide the model as a "
"positional argument instead of via the `--model` option.")
# EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag
asyncio.run(run_zmq_server(args))
def validate(self, args: argparse.Namespace) -> None:
validate_parsed_disagg_args(args)
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
disagg_parser = subparsers.add_parser(
"disagg",
help="Start the vLLM OpenAI Compatible API zmq server",
usage="vllm disagg <model_tag> [options]")
return make_disagg_arg_parser(disagg_parser)
def cmd_init() -> list[CLISubcommand]:
return [DisaggSubcommand()]
def make_disagg_arg_parser(
parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"model_tag",
type=str,
help=
"The model tag to use for the vLLM OpenAI Compatible API zmq server.")
parser.add_argument('--zmq-server-addr',
type=str,
required=True,
help='The address to serve the zmq server on.')
parser.add_argument(
"--disable-frontend-multiprocessing",
action="store_true",
help="If specified, will run the OpenAI frontend server in the same "
"process as the model serving engine.")
parser.add_argument(
"--return-tokens-as-token-ids",
action="store_true",
help="When ``--max-logprobs`` is specified, represents single tokens "
" as strings of the form 'token_id:{token_id}' so that tokens "
"that are not JSON-encodable can be identified.")
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
parser.add_argument(
"--lora-modules",
type=nullable_str,
default=None,
nargs='+',
action=LoRAParserAction,
help="LoRA module configurations in either 'name=path' format"
"or JSON format. "
"Example (old format): ``'name=path'`` "
"Example (new format): "
"``{\"name\": \"name\", \"path\": \"lora_path\", "
"\"base_model_name\": \"id\"}``")
parser.add_argument(
"--prompt-adapters",
type=nullable_str,
default=None,
nargs='+',
action=PromptAdapterParserAction,
help="Prompt adapter configurations in the format name=path. "
"Multiple adapters can be specified.")
AsyncEngineArgs.add_cli_args(parser)
return parser
def validate_parsed_disagg_args(args: argparse.Namespace):
"""Quick checks for model disagg args that raise prior to loading."""
if hasattr(args, "subparser") and args.subparser != "disagg":
return
# Enable reasoning needs a reasoning parser to be valid
if args.enable_reasoning and not args.reasoning_parser:
raise TypeError("Error: --enable-reasoning requires "
"--reasoning-parser")

View File

@@ -6,6 +6,8 @@ import signal
import sys
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.connect
import vllm.entrypoints.cli.disagg
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.serve
import vllm.version
@@ -18,6 +20,8 @@ CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.disagg,
vllm.entrypoints.cli.connect,
]

View File

@@ -3,44 +3,69 @@
import asyncio
import json
import signal
import sys
import traceback
import uuid
from asyncio import Queue
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from typing import Union
import uvicorn
import uvloop
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.entrypoints.openai.zmq_server import (CONTENT_TYPE_ERROR,
CONTENT_TYPE_JSON,
CONTENT_TYPE_STREAM)
from vllm.logger import init_logger
from vllm.utils import FlexibleArgumentParser
# default prefill and decode addr
time_out = 180
socket_prefill_num = 100
socket_decode_num = 100
context_type_json = "application/json"
context_type_error = "error"
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.disagg_connector')
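# seconds to wait for each reply frame forwarded by the socket handlers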
TIME_OUT = 5
X_REQUEST_ID_KEY = "X-Request-Id"
# communication between output handlers and execute_task_async
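# maps X-Request-Id -> per-request queue; the prefill/decode socket handlers
# push (content_type, payload) tuples that execute_task_async consumes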
request_queues: dict[str, asyncio.Queue]
async def log_stats(request_queues: dict[str, asyncio.Queue]):
while True:
logger.info("Running requests: %d", len(request_queues))
await asyncio.sleep(10)
# create async socket use ZMQ_DEALER
async def create_socket(url: str,
zmqctx: zmq.asyncio.Context) -> zmq.asyncio.Socket:
sock = zmqctx.socket(zmq.DEALER)
identity = f"connector-{uuid.uuid4()}"
sock.setsockopt(zmq.IDENTITY, identity.encode())
sock.connect(url)
logger.info("%s started at %s", identity, url)
return sock
@asynccontextmanager
async def lifespan(app: FastAPI):
# create socket pool with prefill and decode
logger.info("start create_socket_pool")
logger.info("start connect zmq server")
app.state.zmqctx = zmq.asyncio.Context()
app.state.sockets_prefill = await create_socket_pool(
app.state.prefill_addr, socket_prefill_num, zmqctx=app.state.zmqctx)
logger.info("success create_socket_pool sockets_prefill")
app.state.sockets_decode = await create_socket_pool(
app.state.decode_addr, socket_decode_num, zmqctx=app.state.zmqctx)
logger.info("success create_socket_pool sockets_decode")
app.state.prefill_socket = await create_socket(app.state.prefill_addr,
zmqctx=app.state.zmqctx)
logger.info("success create_socke sockets_prefill")
app.state.decode_socket = await create_socket(app.state.decode_addr,
zmqctx=app.state.zmqctx)
logger.info("success create_socket sockets_decode")
global request_queues
request_queues = {}
asyncio.create_task(prefill_handler(app.state.prefill_socket))
asyncio.create_task(decode_handler(app.state.decode_socket))
asyncio.create_task(log_stats(request_queues))
yield
## close zmq context
logger.info("shutdown disagg connector")
@@ -51,59 +76,140 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
# create async socket pool with num_sockets use ZMQ_DEALER
async def create_socket_pool(url: str, num_sockets: int,
zmqctx: zmq.asyncio.Context) -> Queue:
sockets: Queue[zmq.Socket] = Queue()
for i in range(num_sockets):
sock = zmqctx.socket(zmq.DEALER)
identity = f"worker-{i}-{uuid.uuid4()}"
sock.setsockopt(zmq.IDENTITY, identity.encode())
sock.connect(url)
logger.info("%s started at %s with queue size %s", identity, url,
sockets.qsize())
await sockets.put(sock)
return sockets
@app.post('/v1/completions')
async def completions(request: Request, background_tasks: BackgroundTasks):
try:
# Add the X-Request-Id header to the raw headers list
header = dict(request.headers)
# request.headers is case-insensitive, but dict() lowercases its keys, so
# look up the incoming request id on the original headers object
request_id = request.headers.get(X_REQUEST_ID_KEY)
queue = asyncio.Queue()
if request_id is None:
request_id = str(uuid.uuid4())
logger.debug("add X-Request-Id: %s", request_id)
header[X_REQUEST_ID_KEY] = request_id
logger.debug("X-Request-Id is: %s", request_id)
request_queues[request_id] = queue
request_data = await request.json()
logger.info("Received request_id: %s, request: %s, header: %s",
request_id, request_data, header)
original_max_tokens = request_data['max_tokens']
# change max_tokens = 1 to let it only do prefill
request_data['max_tokens'] = 1
# finish prefill
try:
prefill_response = await prefill(header, request_data)
if isinstance(prefill_response, JSONResponse):
return prefill_response
logger.debug("finish prefill start decode")
request_data['max_tokens'] = original_max_tokens
response = await decode(header, request_data)
logger.debug("finish decode")
except Exception as e:
logger.error("Error occurred in disagg prefill proxy server, %s",
e)
response = JSONResponse({"error": {
"message": str(e)
}},
status_code=500)
return response
except Exception as e:
exc_info = sys.exc_info()
logger.error("Error occurred in disagg prefill proxy server")
logger.error(e)
logger.error("".join(traceback.format_exception(*exc_info)))
response = JSONResponse({"error": {
"message": str(e)
}},
status_code=500)
return response
finally:
background_tasks.add_task(cleanup_request_id, request_id)
async def socket_recv_handler(socket: zmq.asyncio.Socket, scene: str):
while True:
try:
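# the DEALER socket receives replies framed as [request_id, content_type, payload]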
[request_id, contentType, reply] = await socket.recv_multipart()
contentType_str = contentType.decode()
reply_str = reply.decode()
request_id_str = request_id.decode()
logger.debug(
"%s socket received result contentType: %s, "
"request_id: %s, reply: %s", scene, contentType_str,
request_id_str, reply_str)
if request_id_str in request_queues:
request_queues[request_id_str].put_nowait(
(contentType_str, reply_str))
if "[DONE]" in reply_str:
logger.debug(
"%s socket received stop signal request_id: %s", scene,
request_id_str)
request_queues[request_id_str].put_nowait(
(contentType_str, None))
else:
logger.debug(
"%s socket received but request_id not found discard: %s",
scene, request_id_str)
except Exception as e:
logger.error(traceback.format_exc())
logger.error("%s handler error: %s", scene, e)
# prefill handler
async def prefill_handler(prefill_socket: zmq.asyncio.Socket):
await socket_recv_handler(prefill_socket, "prefill")
# decode handler
async def decode_handler(decode_socket: zmq.asyncio.Socket):
await socket_recv_handler(decode_socket, "decode")
# select a socket and execute task
async def execute_task_async(route: str, headers: dict, request: dict,
sockets: Queue):
sock: zmq.Socket = await sockets.get()
async def execute_task_async(headers: dict, request: dict,
socket: zmq.asyncio.Socket):
try:
request_id = headers.get(X_REQUEST_ID_KEY)
requestBody = json.dumps(request)
headersJson = json.dumps(headers)
logger.info("Sending requestBody: %s to %s with headers: %s",
requestBody, route, headersJson)
await asyncio.wait_for(sock.send_multipart(
[route.encode(),
headersJson.encode(),
requestBody.encode()]),
timeout=time_out)
logger.info("Sending requestBody: %s", requestBody)
await socket.send_multipart([request_id.encode(), requestBody.encode()])
logger.debug("Sent end")
queue = request_queues[request_id]
while True:
logger.debug("Waiting for reply")
[contentType,
reply] = await asyncio.wait_for(sock.recv_multipart(),
timeout=time_out)
contentType_str = contentType.decode()
reply_str = reply.decode()
logger.debug("Received result: %s, %s", contentType_str, reply_str)
yield (contentType_str, reply_str)
if context_type_json == contentType_str:
logger.debug("Received %s message, return socket",
contentType_str)
(contentType,
reply) = await asyncio.wait_for(queue.get(), TIME_OUT)
logger.debug("Received result: %s, %s", contentType, reply)
if reply is None:
logger.debug("Received stop signal, request_id: %s",
request_id)
break
if "[DONE]" in reply_str:
logger.debug("Received stop signal, return socket")
yield (contentType, reply)
if contentType == CONTENT_TYPE_JSON:
logger.debug("Received %s message, request_id: %s",
contentType, request_id)
break
except asyncio.TimeoutError:
logger.error(traceback.format_exc())
logger.error("Timeout, return socket: %s",
sock.getsockopt(zmq.IDENTITY))
yield (context_type_error, "System Error")
yield (CONTENT_TYPE_ERROR, "System Error")
finally:
await sockets.put(sock)
logger.debug("request_id: %s, execute_task_async end", request_id)
async def prefill(header: dict,
original_request_data: dict) -> Union[JSONResponse, bool]:
logger.debug("start prefill")
generator = execute_task_async(header, original_request_data,
app.state.prefill_socket)
async for contentType, reply in generator:
logger.debug("contentType: %s, reply: %s", contentType, reply)
if contentType == CONTENT_TYPE_ERROR:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
return True
async def generate_stream_response(fisrt_reply: str,
@@ -113,78 +219,34 @@ async def generate_stream_response(fisrt_reply: str,
yield reply
async def prefill(route: str, header: dict, original_request_data: dict):
logger.info("start prefill")
generator = execute_task_async(route, header, original_request_data,
app.state.sockets_prefill)
async for contentType, reply in generator:
logger.debug("contentType: %s, reply: %s", contentType, reply)
if context_type_error == contentType:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
return True
async def decode(route: str, header: dict, original_request_data: dict):
async def decode(
header: dict,
original_request_data: dict) -> Union[JSONResponse, StreamingResponse]:
logger.info("start decode")
generator = execute_task_async(route, header, original_request_data,
app.state.sockets_decode)
generator = execute_task_async(header, original_request_data,
app.state.decode_socket)
async for contentType, reply in generator:
logger.debug("contentType: %s, reply: %s", contentType, reply)
if context_type_error == contentType:
if contentType == CONTENT_TYPE_ERROR:
response = JSONResponse({"error": reply})
response.status_code = 500
return response
elif context_type_json == contentType:
elif contentType == CONTENT_TYPE_JSON:
return JSONResponse(reply)
else:
return StreamingResponse(generate_stream_response(
reply, generator),
media_type="text/event-stream")
media_type=CONTENT_TYPE_STREAM)
@app.post('/v1/completions')
async def chat_completions(request: Request):
try:
# Add the X-Request-Id header to the raw headers list
x_request_id = str(uuid.uuid4())
header = dict(request.headers)
if header.get("X-Request-Id") is None:
logger.info("add X-Request-Id: %s", x_request_id)
header["X-Request-Id"] = x_request_id
request_data = await request.json()
logger.info("Received request: %s header: %s", request_data, header)
original_max_tokens = request_data['max_tokens']
# change max_tokens = 1 to let it only do prefill
request_data['max_tokens'] = 1
route = "/v1/completions"
# finish prefill
try:
prefill_response = await prefill(route, header, request_data)
if isinstance(prefill_response, JSONResponse):
return prefill_response
logger.info("finish prefill start decode")
request_data['max_tokens'] = original_max_tokens
response = await decode(route, header, request_data)
logger.info("finish decode")
except Exception as e:
logger.error("Error occurred in disagg prefill proxy server, %s",
e)
response = JSONResponse({"error": {"message": str(e)}})
return response
except Exception as e:
import sys
import traceback
exc_info = sys.exc_info()
logger.error("Error occurred in disagg prefill proxy server")
logger.error(e)
logger.error("".join(traceback.format_exception(*exc_info)))
def cleanup_request_id(request_id: str):
if request_id in request_queues:
logger.info("del request_id: %s, decode finished", request_id)
del request_queues[request_id]
async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
async def run_disagg_connector(args, **uvicorn_kwargs):
logger.info("vLLM Disaggregate Connector start %s %s", args,
uvicorn_kwargs)
logger.info(args.prefill_addr)
@@ -192,8 +254,9 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
app.state.prefill_addr = f"ipc://{args.prefill_addr}"
app.state.decode_addr = f"ipc://{args.decode_addr}"
logger.info(
"start connect prefill_addr: %s decode_addr: %s zmq server port: %s",
app.state.prefill_addr, app.state.decode_addr, app.state.port)
"start connect prefill_addr: %s decode_addr: %s "
"zmq server fastapi port: %s", app.state.prefill_addr,
app.state.decode_addr, app.state.port)
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
@@ -208,20 +271,22 @@ async def run_disagg_connector(args, **uvicorn_kwargs) -> None:
if __name__ == "__main__":
# NOTE(simon):
# This section should be in sync with vllm/scripts.py for CLI entrypoints.
parser = FlexibleArgumentParser(description="vLLM disagg zmq server.")
# This section should be kept in sync with vllm/entrypoints/cli/connect.py
# for CLI entrypoints.
parser = FlexibleArgumentParser(description="vLLM disagg connect server.")
parser.add_argument("--port",
type=int,
default=8000,
help="The fastapi server port")
default=8001,
help="The fastapi server port default 8001")
# due to security concerns, only ipc is supported for now
parser.add_argument("--prefill-addr",
type=str,
required=True,
help="The prefill address IP:PORT")
help="The zmq ipc prefill address")
parser.add_argument("--decode-addr",
type=str,
required=True,
help="The decode address IP:PORT")
help="The zmq ipc decode address")
args = parser.parse_args()

View File

@@ -7,16 +7,12 @@ from http import HTTPStatus
from typing import Any, Optional
import uvicorn
import zmq
import zmq.asyncio
import zmq.devices
from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.multiprocessing import MQEngineDeadError
from vllm.entrypoints.ssl import SSLCertRefresher
from vllm.entrypoints.openai.connect_worker import worker_routine
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
@@ -79,43 +75,6 @@ async def serve_http(app: FastAPI,
return server.shutdown()
async def serve_zmq(arg, zmq_server_port: int, app: FastAPI) -> None:
"""Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_port: %d", arg,
zmq_server_port)
# different zmq context can't communicate use inproc
workers_addr = "ipc://workers"
clients_addr = f"ipc://127.0.0.1:{zmq_server_port}"
# Prepare our context and sockets
context = zmq.asyncio.Context.instance()
try:
tasks = [
asyncio.create_task(worker_routine(workers_addr, app, context, i))
for i in range(100)
]
logger.info("zmq tasks: %s", tasks)
# thread safety proxy create socket in the background:
# https://pyzmq.readthedocs.io/en/latest/api/zmq.devices.html#proxy-devices
thread_proxy = zmq.devices.ThreadProxy(zmq.ROUTER, zmq.DEALER)
# unlimited HWM
hwm_limit = 0
thread_proxy.bind_in(clients_addr)
thread_proxy.setsockopt_in(zmq.SNDHWM, hwm_limit)
thread_proxy.setsockopt_in(zmq.RCVHWM, hwm_limit)
thread_proxy.bind_out(workers_addr)
thread_proxy.setsockopt_out(zmq.SNDHWM, hwm_limit)
thread_proxy.setsockopt_out(zmq.RCVHWM, hwm_limit)
thread_proxy.start()
await asyncio.gather(*tasks)
except KeyboardInterrupt:
print("ZMQ Server interrupted")
except zmq.ZMQError as e:
print("ZMQError:", e)
finally:
# We never get here but clean up anyhow
context.destroy(linger=0)
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""

View File

@@ -36,7 +36,7 @@ from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.launcher import serve_http, serve_zmq
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import (make_arg_parser,
validate_parsed_serve_args)

View File

@@ -1,139 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import json
import tempfile
import traceback
import uuid
from typing import Optional
import httpx
import zmq
import zmq.asyncio
from fastapi import FastAPI, Request
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.serving_tokenization import (
OpenAIServingTokenization)
from vllm.logger import init_logger
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.connect_worker')
def base(app: FastAPI) -> OpenAIServing:
# Reuse the existing instance
return tokenization(app)
def models(app: FastAPI) -> OpenAIServingModels:
return app.state.openai_serving_models
def chat(app: FastAPI) -> Optional[OpenAIServingChat]:
return app.state.openai_serving_chat
def completion(app: FastAPI) -> Optional[OpenAIServingCompletion]:
return app.state.openai_serving_completion
def tokenization(app: FastAPI) -> OpenAIServingTokenization:
return app.state.openai_serving_tokenization
def bytes_to_headers(bytes_data: bytes) -> httpx.Headers:
headers_dict = json.loads(bytes_data.decode())
return httpx.Headers(headers_dict)
async def worker_routine(worker_addr: str, app: FastAPI,
context: zmq.asyncio.Context, i: int = 0):
"""Worker routine"""
try:
# Socket to talk to dispatcher
socket = context.socket(zmq.DEALER)
worker_identity = f"worker-{i}-{uuid.uuid4()}"
socket.setsockopt(zmq.IDENTITY, worker_identity.encode())
socket.connect(worker_addr)
logger.info("%s started at %s", worker_identity, worker_addr)
while True:
identity, url, header, body = await socket.recv_multipart()
logger.info("worker-%d Received request identity: [ %s ]",
i, identity.decode())
url_str = url.decode()
logger.info("worker-%d Received request url: [ %s ]",
i, url_str)
headers = bytes_to_headers(header)
logger.info("worker-%d Received request headers: [ %s ]",
i, headers)
body_json = json.loads(body.decode())
logger.info("worker-%d Received request body: [ %s ]",
i, body_json)
logger.info("worker-%d Calling OpenAI API", i)
completionRequest = CompletionRequest(**body_json)
createRequest = create_request(url_str, "POST", body_json, headers)
generator = await create_completion(app, completionRequest,
createRequest)
context_type_json = b"application/json"
if isinstance(generator, ErrorResponse):
content = generator.model_dump_json()
context_json = json.loads(content)
context_json.append("status_code", generator.code)
await socket.send_multipart([identity, context_type_json,
json.dumps(context_json).encode('utf-8')])
elif isinstance(generator, CompletionResponse):
await socket.send_multipart([identity,
context_type_json,
json.dumps(generator.model_dump()).encode('utf-8')])
else:
async for chunk in generator:
await socket.send_multipart([identity,
b"text/event-stream",
chunk.encode('utf-8')])
except Exception as e:
logger.error("Error in worker routine: %s worker-%d", e, i)
logger.error(traceback.format_exc())
async def create_completion(app: FastAPI, request: CompletionRequest,
raw_request: Request):
handler = completion(app)
logger.info("zmq request post: %s", request)
if handler is None:
return base(app).create_error_response(
message="The model does not support Completions API")
generator = await handler.create_completion(request, raw_request)
logger.info("zmq request end post: %s", generator)
return generator
def create_request(path: str, method: str, body: dict,
headers: httpx.Headers) -> Request:
scope = {
'type': 'http',
'http_version': '1.1',
'method': method,
'path': path,
'headers': list(headers.items()) if headers else [],
}
if body:
scope['body'] = json.dumps(body)
async def receive():
return {
'type': 'http.request',
'body': scope.get('body', b''),
}
async def send(message):
pass
return Request(scope, receive=receive, send=send)
if __name__ == "__main__":
print(bytes_to_headers(b'{"Content-Type": "application/json"}'))

View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
import json
import os
import signal
import traceback
from argparse import Namespace
import zmq
import zmq.asyncio
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.protocol import (CompletionRequest,
CompletionResponse,
ErrorResponse)
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.logger import init_logger
from vllm.utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger('vllm.entrypoints.openai.zmq_server')
CONTENT_TYPE_JSON = "application/json"
CONTENT_TYPE_ERROR = "error"
CONTENT_TYPE_STREAM = "text/event-stream"
openai_serving_completion: OpenAIServingCompletion
openai_serving_models: OpenAIServingModels
async def log_stats(running_requests: set[asyncio.Task]):
while True:
logger.info("Running requests: %d", len(running_requests))
await asyncio.sleep(10)
def _cleanup_ipc_path(server_addr: str):
socket_path = server_addr.replace("ipc://", "")
logger.info("cleaning up local IPC socket file %s", socket_path)
if os.path.exists(socket_path):
os.remove(socket_path)
async def serve_zmq(arg) -> None:
"""Server routine"""
logger.info("zmq Server start arg: %s, zmq_server_addr: %s", arg,
arg.zmq_server_addr)
# different zmq contexts cannot communicate over inproc, so use ipc
server_addr = f"ipc://{arg.zmq_server_addr}"
try:
# Prepare our context and sockets
context = zmq.asyncio.Context()
socket = context.socket(zmq.ROUTER)
# unlimited HWM; set the options before bind so they take effect
hwm_limit = 0
socket.setsockopt(zmq.SNDHWM, hwm_limit)
socket.setsockopt(zmq.RCVHWM, hwm_limit)
socket.bind(server_addr)
running_requests: set[asyncio.Task] = set()
logger.info("zmq Server started at %s", server_addr)
asyncio.create_task(log_stats(running_requests))
while True:
try:
logger.debug("zmq Server waiting for request")
# get new request from the client
identity, request_id, body = await socket.recv_multipart()
# launch request handler coroutine
task = asyncio.create_task(
worker_routine(identity, request_id, body, socket))
running_requests.add(task)
task.add_done_callback(running_requests.discard)
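# running_requests keeps a strong reference to each in-flight task so it is
# not garbage collected; the done-callback drops it on completion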
except zmq.ZMQError as e:
logger.error("ZMQError: %s", e)
break
except Exception as e:
logger.error("Unexpected error: %s", e)
break
except KeyboardInterrupt:
logger.info("KeyboardInterrupt received, exiting")
finally:
# Clean up resources
for task in running_requests:
task.cancel()
await asyncio.gather(*running_requests, return_exceptions=True)
socket.close()
context.destroy(linger=0)
_cleanup_ipc_path(server_addr)
async def run_zmq_server(args) -> None:
logger.info("vLLM zmq server version %s", VLLM_VERSION)
logger.info("args: %s", args)
# workaround to avoid footguns where uvicorn drops requests with too
# many concurrent requests active
set_ulimit()
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as engine_client:
model_config = await engine_client.get_model_config()
await init_state(engine_client, model_config, args)
logger.info("init_state successful")
await serve_zmq(args)
async def init_state(
engine_client: EngineClient,
model_config: ModelConfig,
args: Namespace,
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
base_model_paths = [
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
]
global openai_serving_models
openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
await openai_serving_models.init_static_loras()
global openai_serving_completion
openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
openai_serving_models,
request_logger=None,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
async def worker_routine(identity: bytes, request_id: bytes, body: bytes,
socket: zmq.asyncio.Socket):
"""Worker routine"""
try:
# decode the request id first so the except block below can log it
request_id_str = request_id.decode()
body_json = json.loads(body.decode())
logger.debug("receive request: %s from %s, request_id: %s", body_json,
identity.decode(), request_id_str)
completionRequest = CompletionRequest(**body_json)
generator = await create_completion(completionRequest, None)
content_type_json = CONTENT_TYPE_JSON.encode('utf-8')
content_type_stream = CONTENT_TYPE_STREAM.encode('utf-8')
if isinstance(generator, ErrorResponse):
content = json.loads(generator.model_dump_json())
content.update({"status_code": generator.code})
logger.debug("send ErrorResponse %s", json.dumps(content))
await socket.send_multipart([
identity, request_id, content_type_json,
json.dumps(content).encode('utf-8')
])
elif isinstance(generator, CompletionResponse):
logger.debug("send CompletionResponse %s",
json.dumps(generator.model_dump()))
await socket.send_multipart([
identity, request_id, content_type_json,
json.dumps(generator.model_dump()).encode('utf-8')
])
else:
async for chunk in generator:
logger.debug(
"send chunk identity: %s, request_id: %s, chunk: %s",
identity.decode(), request_id.decode(), chunk)
await socket.send_multipart([
identity, request_id, content_type_stream,
chunk.encode('utf-8')
])
except Exception as e:
logger.error("Error in worker routine: %s request_id: %s", e,
request_id_str)
logger.error(traceback.format_exc())
logger.debug("send ErrorResponse %s", str(e))
await socket.send_multipart([
identity, request_id,
CONTENT_TYPE_ERROR.encode('utf-8'),
str(e).encode('utf-8')
])
async def create_completion(request: CompletionRequest, raw_request: Request):
logger.debug("zmq request post: %s", request)
generator = await openai_serving_completion.create_completion(
request, raw_request)
logger.debug("zmq request end post")
return generator